Commit b9d6cbd
[MLIR] Folding unpack and pack sequence in data layout propagation from padded domain (#138332)
The `DataLayoutPropagation` patterns can create a sequence of an unpack op followed by a pack op. Such sequences tend to disrupt tiling and can be optimized away; this is especially true for pack and unpack ops on padded values. The idea of this patch is to optimize the propagation by never creating the unpack + pack pair when the padding value does not matter to the op being propagated through. In particular, the `PushDownUnPackOpThroughGenericOp` pattern can fold the pair away: if an operand of the generic op happens to come from an unpack, there is no need to create a new pack of that operand; we can fold the unpack -> pack sequence and use the original source of the unpack op directly.
1 parent 9602216 commit b9d6cbd
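Schematically, the patch removes the round-trip through the unpacked domain that `PushDownUnPackOpThroughGenericOp` used to materialize. Below is a minimal sketch mirroring the new `@push_unpack_in_padded_domain_foldable` test added in this commit; the SSA names and the exact re-pack shown are illustrative:

```mlir
// Before: pushing the unpack below the generic re-packed its operand.
%cst = arith.constant 0.000000e+00 : f32
%unpack = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [4, 8]
    into %dest : tensor<8x8x4x8xf32> -> tensor<?x64xf32>
%repack = linalg.pack %unpack padding_value(%cst : f32)
    inner_dims_pos = [0, 1] inner_tiles = [4, 8]
    into %empty : tensor<?x64xf32> -> tensor<8x8x4x8xf32>
// %repack then fed the packed generic. Because the re-pack uses exactly the
// inner_dims_pos/inner_tiles of the unpack, the pair is a layout round-trip:
// after this patch the packed generic reads %arg0 directly, and only a single
// unpack of its result remains.
```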

File tree

2 files changed: +91 -58 lines changed

mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp

Lines changed: 32 additions & 4 deletions
```diff
@@ -298,20 +298,42 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
   return std::make_tuple(packedOperand, indexingMap);
 }
 
-/// Pack a genericOp and return it.
+/// This function is a helper subroutine to pack a genericOp and return it. It
+/// will create a new generic op with the packed operand and the packed output
+/// according to packInfo when we attempt to push down unpack or bubble up pack
+/// around it. Implicitly this will only work when a packInfo can be obtained.
+/// This make sure that we are only using this function on parallel permuted
+/// dimensions.
 static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
                                Value dest, AffineMap packedOutIndexingMap,
-                               const PackInfo &packInfo) {
+                               const PackInfo &packInfo,
+                               bool isFoldableUnpackPack) {
   Location loc = genericOp.getLoc();
   SmallVector<Value> inputOperands;
+  SmallVector<Value> inputOperandsFromUnpackedSource;
   SmallVector<AffineMap> indexingMaps;
   for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
     auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
         rewriter, loc, packInfo, genericOp, inputOperand);
+    if (auto unpackOp = inputOperand->get().getDefiningOp<linalg::UnPackOp>()) {
+      inputOperandsFromUnpackedSource.push_back(unpackOp.getSource());
+    } else {
+      inputOperandsFromUnpackedSource.push_back(packedOperand);
+    }
     inputOperands.push_back(packedOperand);
     indexingMaps.push_back(packedIndexingMap);
   }
 
+  // If the pack and unpack op can be folded:
+  // 1) use unpack op source op for operand to fold unpack -> pack sequence.
+  // 2) init tensor of the generic op can be replaced by the destination of the
+  // pack op.
+  if (isFoldableUnpackPack) {
+    inputOperands = inputOperandsFromUnpackedSource;
+    if (auto destPack = dest.getDefiningOp<linalg::PackOp>())
+      dest = destPack.getDest();
+  }
+
   int64_t numInnerLoops = packInfo.getNumTiledLoops();
   SmallVector<utils::IteratorType> iterTypes =
       genericOp.getIteratorTypesArray();
@@ -447,8 +469,10 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
           .getDefiningOp<tensor::EmptyOp>()) {
     dest = packOpDest;
   }
+  // pack(unpack) isn't naively foldable because the unpack op can be from
+  // an arbitrary domain so we need to keep both.
   return packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap,
-                       *packInfo);
+                       *packInfo, /*isFoldableUnpackPack=*/false);
 }
 
 /// Wrapper pattern that applies bubbleUpPackOpThroughGenericOp method.
@@ -1085,8 +1109,12 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
   }
 
   // Pack the genericOp.
+  // pack(unpack) is foldable in this case. This is because in pushing down the
+  // unpack, by default we will populate an additional pack op after the unpack.
+  // This guarantees them to be foldable.
   GenericOp newGenericOp =
-      packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, *packInfo);
+      packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, *packInfo,
+                    /*isFoldableUnpackPack=*/true);
   Value newResult =
       newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0));
```
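Why the bubble-up path passes `/*isFoldableUnpackPack=*/false`: a pack(unpack) pair is only a round-trip when both ops agree on the layout, and in `bubbleUpPackOpThroughGenericOp` the unpack can come from an arbitrary domain. A hypothetical mismatched pair (illustrative, not from this patch) genuinely relayouts the data and must be kept:

```mlir
func.func @hypothetical_mismatched_layouts(%a: tensor<4x16x16xf32>,
    %d0: tensor<64x16xf32>, %d1: tensor<64x2x8xf32>) -> tensor<64x2x8xf32> {
  // The unpack tiles dim 0 by 16; the pack below tiles dim 1 by 8 instead,
  // so pack(unpack) here is real data movement, not an identity.
  %u = linalg.unpack %a inner_dims_pos = [0] inner_tiles = [16]
      into %d0 : tensor<4x16x16xf32> -> tensor<64x16xf32>
  %p = linalg.pack %u inner_dims_pos = [1] inner_tiles = [8]
      into %d1 : tensor<64x16xf32> -> tensor<64x2x8xf32>
  return %p : tensor<64x2x8xf32>
}
```

In the push-down pattern, by contrast, the pattern itself creates the pack with the same layout attributes as the unpack it is pushing, which is what guarantees foldability there.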

mlir/test/Dialect/Linalg/data-layout-propagation.mlir

Lines changed: 59 additions & 54 deletions
```diff
@@ -455,13 +455,10 @@ func.func @unpack_on_output(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x56x56
 // CHECK: %[[UNPACKED_ARG0:.+]] = linalg.unpack %[[ARG0]]
 // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME: into %[[ARG0_EMPTY_UNPACK]]
-// CHECK: %[[ARG0_EMPTY_PACK:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK: %[[PACKED_ARG0:.+]] = linalg.pack %[[UNPACKED_ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG0_EMPTY_PACK]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
 // CHECK: %[[RES:.+]] = linalg.generic
 // CHECK-SAME: indexing_maps = [#[[$MAP]]]
-// CHECK-SAME: outs(%[[PACKED_ARG0]]
+// CHECK-SAME: outs(%[[EMPTY]]
 // CHECK: %[[UNPACK:.+]] = linalg.unpack %[[RES]]
 // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME: into %[[UNPACKED_ARG0]]
@@ -485,22 +482,11 @@ func.func @unpack_on_input(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56
 // CHECK-LABEL: func.func @unpack_on_input
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
-// CHECK: %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
-// CHECK: %[[UNPACKED_ARG0:.+]] = linalg.unpack %[[ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
-// CHECK: %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK: %[[ARG1_PACK:.+]] = linalg.pack %[[ARG1]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG1_PACK_EMPTY]]
-// CHECK: %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK: %[[ARG0_PACK:.+]] = linalg.pack %[[UNPACKED_ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG0_PACK_EMPTY]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
 // CHECK: %[[RES:.+]] = linalg.generic
 // CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
-// CHECK-SAME: ins(%[[ARG0_PACK]]
-// CHECK-SAME: outs(%[[ARG1_PACK]]
+// CHECK-SAME: ins(%[[ARG0]]
+// CHECK-SAME: outs(%[[EMPTY]]
 // CHECK: %[[UNPACK:.+]] = linalg.unpack %[[RES]]
 // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME: into %[[ARG1]]
@@ -524,22 +510,11 @@ func.func @unpack_element_type_change(%arg0: tensor<12x2x56x56x32xf32>, %init: t
 // CHECK-LABEL: func.func @unpack_element_type_change
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
-// CHECK: %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
-// CHECK: %[[UNPACKED_ARG0:.+]] = linalg.unpack %[[ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
-// CHECK: %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
-// CHECK: %[[ARG1_PACK:.+]] = linalg.pack %[[ARG1]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG1_PACK_EMPTY]]
-// CHECK: %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK: %[[ARG0_PACK:.+]] = linalg.pack %[[UNPACKED_ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG0_PACK_EMPTY]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
 // CHECK: %[[RES:.+]] = linalg.generic
 // CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
-// CHECK-SAME: ins(%[[ARG0_PACK]]
-// CHECK-SAME: outs(%[[ARG1_PACK]]
+// CHECK-SAME: ins(%[[ARG0]]
+// CHECK-SAME: outs(%[[EMPTY]]
 // CHECK: %[[UNPACK:.+]] = linalg.unpack %[[RES]]
 // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME: into %[[ARG1]]
@@ -564,19 +539,11 @@ func.func @forward_tensor_empty(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x5
 // CHECK-LABEL: func.func @forward_tensor_empty
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK: %[[FINAL_RES:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
-// CHECK: %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
-// CHECK: %[[UNPACKED_ARG0:.+]] = linalg.unpack %[[ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
-// CHECK: %[[DEST:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK: %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK: %[[PACKED_ARG0:.+]] = linalg.pack %[[UNPACKED_ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG0_PACK_EMPTY]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
 // CHECK: %[[RES:.+]] = linalg.generic
 // CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
-// CHECK-SAME: ins(%[[PACKED_ARG0]]
-// CHECK-SAME: outs(%[[DEST]]
+// CHECK-SAME: ins(%[[ARG0]]
+// CHECK-SAME: outs(%[[EMPTY]]
 // CHECK: %[[UNPACKED:.+]] = linalg.unpack %[[RES]]
 // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME: into %[[FINAL_RES]]
@@ -810,12 +777,9 @@ func.func @unpack_empty_inner_dims(%arg0: tensor<12x64x56x56xf32>) -> tensor<12x
 }
 
 // CHECK-LABEL: func.func @unpack_empty_inner_dims
-// CHECK: %[[UNPACKED_ARG0:.+]] = linalg.unpack
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
-// CHECK: %[[PACKED_ARG0:.+]] = linalg.pack %[[UNPACKED_ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<12x64x56x56xf32>)
 // CHECK: %[[RES:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[PACKED_ARG0]]
+// CHECK-SAME: ins(%[[ARG0]]
 // CHECK: %[[UNPACKED:.+]] = linalg.unpack %[[RES]]
 // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
 
@@ -943,14 +907,10 @@ func.func @unpack_different_destination_shape(%arg0: tensor<1x1x1080x1920x16xi32
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
 // CHECK: %[[FINAL_RES:.+]] = tensor.empty() : tensor<16x540x960xi32>
 // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x540x960x16xi32>
-// CHECK: %[[PACK_EMPTY:.+]] = tensor.empty() : tensor<1x1x1080x1920x16xi32>
-// CHECK: %[[PACK_ARG0:.+]] = linalg.pack
-// CHECK-SAME: inner_dims_pos = [1] inner_tiles = [16]
-// CHECK-SAME: into %[[PACK_EMPTY]]
 // CHECK: %[[POOL:.+]] = linalg.generic
 // CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
 // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
-// CHECK-SAME: ins(%[[PACK_ARG0]], %[[ARG1]]
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]]
 // CHECK-SAME: outs(%[[INIT]]
 // CHECK: %[[UNPACK:.+]] = linalg.unpack %[[POOL]]
 // CHECK-SAME: inner_dims_pos = [0] inner_tiles = [16]
@@ -1421,3 +1381,48 @@ func.func @no_push_down_unpack_through_non_divisible_expand(%5: tensor<384x32x8x
 // CHECK: %[[UNPACK:.+]] = linalg.unpack %[[ARG0]]
 // CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[UNPACK]] {{\[}}[0, 1], [2]] output_shape [256, 12, 256] : tensor<3072x256xf32> into tensor<256x12x256xf32>
 // CHECK: return %[[EXPANDED]] : tensor<256x12x256xf32>
+
+// -----
+
+func.func @push_unpack_in_padded_domain_foldable(%arg0: tensor<8x8x4x8xf32>, %dest: tensor<?x64xf32>, %arg1: tensor<?x64xbf16>) -> tensor<?x64xbf16> {
+  %unpack = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %dest : tensor<8x8x4x8xf32> -> tensor<?x64xf32>
+  %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%unpack : tensor<?x64xf32>) outs(%arg1 : tensor<?x64xbf16>) {
+  ^bb0(%in: f32, %out: bf16):
+    %1 = arith.truncf %in : f32 to bf16
+    linalg.yield %1 : bf16
+  } -> tensor<?x64xbf16>
+  return %0 : tensor<?x64xbf16>
+}
+// CHECK-LABEL: func.func @push_unpack_in_padded_domain_foldable
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]] : tensor<8x8x4x8xf32>)
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<?x8x4x8xbf16>)
+// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[GENERIC]]
+// CHECK-SAME: into %[[ARG2]]
+// CHECK: return %[[UNPACK]] : tensor<?x64xbf16>
+
+// -----
+
+func.func @push_unpack_in_padded_domain_out_used(%arg0: tensor<8x8x4x8xf32>, %arg1: tensor<?x64xf32>) -> tensor<?x64xf32> {
+  %unpack = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %arg1 : tensor<8x8x4x8xf32> -> tensor<?x64xf32>
+  %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%unpack : tensor<?x64xf32>) outs(%arg1 : tensor<?x64xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %1 = arith.addf %in, %out : f32
+    linalg.yield %1 : f32
+  } -> tensor<?x64xf32>
+  return %0 : tensor<?x64xf32>
+}
+// CHECK-LABEL: func.func @push_unpack_in_padded_domain_out_used
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]] : tensor<8x8x4x8xf32>)
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<?x8x4x8xf32>)
+// CHECK: %[[UNPACK2:.+]] = linalg.unpack %[[GENERIC]]
+// CHECK-SAME: into %[[ARG1]]
+// CHECK: return %[[UNPACK2]] : tensor<?x64xf32>
```
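Both new tests are elementwise over parallel dimensions, so whatever the padded lanes hold is computed on and then discarded by the trailing unpack; that is what makes reading the packed source directly legal here. A hypothetical generic that reduces over the padded dimension (illustrative only; such cases do not take this fold) shows why the padding value would otherwise leak into the result:

```mlir
func.func @hypothetical_reduction(%arg0: tensor<8x8x4x8xf32>,
    %dest: tensor<?x64xf32>, %acc: tensor<64xf32>) -> tensor<64xf32> {
  // The unpack drops the padded rows of the dynamic dimension. Reducing the
  // packed source directly would sum those padded rows into %acc as well.
  %unpack = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [4, 8]
      into %dest : tensor<8x8x4x8xf32> -> tensor<?x64xf32>
  %red = linalg.generic
      {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                        affine_map<(d0, d1) -> (d1)>],
       iterator_types = ["reduction", "parallel"]}
      ins(%unpack : tensor<?x64xf32>) outs(%acc : tensor<64xf32>) {
    ^bb0(%in: f32, %out: f32):
      %s = arith.addf %in, %out : f32
      linalg.yield %s : f32
  } -> tensor<64xf32>
  return %red : tensor<64xf32>
}
```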
