
Commit 8cfd9b8

[MLIR] Make generic skip packing init operand when not used in DataLayoutPropagation (#146139)
In both `bubbleUpPackOpThroughGenericOp()` and `pushDownUnPackOpThroughGenericOp()`, we can simplify the lowered IR by skipping the pack of the init operand when it is not used in the generic op. Instead of packing an empty tensor, the empty tensor can be forwarded directly as the generic op's output destination. This yields cleaner IR after data layout propagation.
1 parent 08cf6ae · commit 8cfd9b8
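Concretely, for the bubble-up case the rewrite now produces IR like the following (a condensed before/after sketch based on the @elem_pack_transpose_outer_dims_unused_init test added below; value names are illustrative):

// Before propagation: %init is never read in the generic's body, and the
// result is packed into a fresh tensor.empty.
#map0 = affine_map<(d0, d1) -> (d0, d1)>
%elem = linalg.generic {indexing_maps = [#map0, #map0],
                        iterator_types = ["parallel", "parallel"]}
    ins(%arg0 : tensor<128x256xi32>) outs(%init : tensor<128x256xi32>) {
  ^bb0(%in: i32, %out: i32):
    %0 = arith.addi %in, %in : i32
    linalg.yield %0 : i32
} -> tensor<128x256xi32>
%empty = tensor.empty() : tensor<16x4x32x16xi32>
%pack = linalg.pack %elem outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
    inner_tiles = [32, 16] into %empty
    : tensor<128x256xi32> -> tensor<16x4x32x16xi32>

// After propagation: only the input is packed; the packed-domain
// tensor.empty is forwarded as the generic's outs, and %init (write-only)
// is never packed.
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
%out_empty = tensor.empty() : tensor<16x4x32x16xi32>
%in_empty = tensor.empty() : tensor<16x4x32x16xi32>
%packed_arg0 = linalg.pack %arg0 outer_dims_perm = [1, 0]
    inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %in_empty
    : tensor<128x256xi32> -> tensor<16x4x32x16xi32>
%res = linalg.generic {indexing_maps = [#map1, #map1],
                       iterator_types = ["parallel", "parallel",
                                         "parallel", "parallel"]}
    ins(%packed_arg0 : tensor<16x4x32x16xi32>)
    outs(%out_empty : tensor<16x4x32x16xi32>) {
  ^bb0(%in: i32, %out: i32):
    %0 = arith.addi %in, %in : i32
    linalg.yield %0 : i32
} -> tensor<16x4x32x16xi32>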

File tree: 2 files changed, +62 -22 lines

mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp (22 additions & 10 deletions)
@@ -358,6 +358,12 @@ static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
   return newGenericOp;
 }
 
+static bool isGenericOutsNotUsed(linalg::GenericOp genericOp) {
+  return llvm::all_of(genericOp.getDpsInitsMutable(), [&](OpOperand &operand) {
+    return genericOp.getMatchingBlockArgument(&operand).use_empty();
+  });
+}
+
 /// Bubbles up linalg.pack op through a producer generic op. This
 /// swap pack(generic) to generic(pack). The new generic op works on packed
 /// domain; pack ops are created for input and output operands. E.g.,
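The new helper returns true exactly when no init's block argument is used in the body, i.e., every output is write-only. A minimal sketch of such a generic (names illustrative):

#id = affine_map<(d0, d1) -> (d0, d1)>
%r = linalg.generic {indexing_maps = [#id, #id],
                     iterator_types = ["parallel", "parallel"]}
    ins(%a : tensor<8x8xf32>) outs(%init : tensor<8x8xf32>) {
  ^bb0(%in: f32, %out: f32):  // %out, the init's block argument, has no uses
    linalg.yield %in : f32
} -> tensor<8x8xf32>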
@@ -470,12 +476,15 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
       getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
                                      genericOp, opOperand);
 
-  // If the dps init operand of the generic is a tensor.empty forward the pack
-  // op destination.
+  // Forward the new tensor.empty as a destination if it is one of the following
+  // situations:
+  // 1) The dps init operand is a tensor.empty.
+  // 2) The dps init is a write-only operand, i.e., it is not used in the
+  // genericOp
   Value dest = packedOutOperand;
-  if (auto initTensor = genericOp.getDpsInitOperand(0)
-                            ->get()
-                            .getDefiningOp<tensor::EmptyOp>()) {
+  auto initTensor =
+      genericOp.getDpsInitOperand(0)->get().getDefiningOp<tensor::EmptyOp>();
+  if (initTensor || isGenericOutsNotUsed(genericOp)) {
     dest = packOpDest;
   }
   // pack(unpack) isn't naively foldable because the unpack op can be from
@@ -1101,12 +1110,15 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
                                      genericOp, genericOp.getDpsInitOperand(0));
   auto destPack = packedOutOperand.getDefiningOp<linalg::PackOp>();
 
-  // If the dps init operand of the generic is a tensor.empty, do not pack it
-  // and forward the new tensor.empty as a destination.
+  // Forward the new tensor.empty as a destination if it is one of the following
+  // situations:
+  // 1) The dps init operand is a tensor.empty.
+  // 2) The dps init is a write-only operand, i.e., it is not used in the
+  // genericOp
   Value dest = packedOutOperand;
-  if (auto initTensor = genericOp.getDpsInitOperand(0)
-                            ->get()
-                            .getDefiningOp<tensor::EmptyOp>()) {
+  auto initTensor =
+      genericOp.getDpsInitOperand(0)->get().getDefiningOp<tensor::EmptyOp>();
+  if (initTensor || isGenericOutsNotUsed(genericOp)) {
     if (destPack)
       dest = destPack.getDest();
   }
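The push-down path applies the same rule: when the init is a tensor.empty or is write-only, the destination of the pack that would have packed the init (destPack.getDest(), a fresh tensor.empty in the packed layout) feeds the generic's outs directly. A condensed sketch of the resulting IR, mirroring the updated @unpack_element_type_change_no_use test below (the arith.truncf body is an assumption; the hunk elides it):

#map5 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
%empty = tensor.empty() : tensor<12x2x56x56x32xf16>
%res = linalg.generic {indexing_maps = [#map5, #map5],
    iterator_types = ["parallel", "parallel", "parallel",
                      "parallel", "parallel"]}
    ins(%arg0 : tensor<12x2x56x56x32xf32>)
    outs(%empty : tensor<12x2x56x56x32xf16>) {
  ^bb0(%in: f32, %out: f16):
    %t = arith.truncf %in : f32 to f16
    linalg.yield %t : f16
} -> tensor<12x2x56x56x32xf16>
%unpack = linalg.unpack %res outer_dims_perm = [0, 3, 1, 2]
    inner_dims_pos = [3] inner_tiles = [32] into %init
    : tensor<12x2x56x56x32xf16> -> tensor<12x56x56x64xf16>

The generic writes straight into %empty instead of into a pack of %init; the trailing unpack still uses %init as its destination.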

mlir/test/Dialect/Linalg/data-layout-propagation.mlir (40 additions & 12 deletions)
@@ -435,6 +435,40 @@ func.func @elem_pack_transpose_outer_dims(%arg0: tensor<128x256xi32>, %init: ten
 
 // -----
 
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+func.func @elem_pack_transpose_outer_dims_unused_init(%arg0: tensor<128x256xi32>, %init: tensor<128x256xi32>) -> tensor<16x4x32x16xi32>{
+  %elem = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel", "parallel"]}
+      ins(%arg0 : tensor<128x256xi32>)
+      outs(%init : tensor<128x256xi32>) {
+    ^bb0(%arg3: i32, %arg4: i32):
+      %4 = arith.addi %arg3, %arg3 : i32
+      linalg.yield %4 : i32
+  } -> tensor<128x256xi32>
+  %empty = tensor.empty() : tensor<16x4x32x16xi32>
+  %pack = linalg.pack %elem
+    outer_dims_perm = [1, 0]
+    inner_dims_pos = [0, 1]
+    inner_tiles = [32, 16]
+    into %empty : tensor<128x256xi32> -> tensor<16x4x32x16xi32>
+  return %pack : tensor<16x4x32x16xi32>
+}
+
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func.func @elem_pack_transpose_outer_dims
+// CHECK-SAME:  %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:  %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK:       %[[ARG1_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
+// CHECK:       %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
+// CHECK:       %[[PACKED_ARG0:.+]] = linalg.pack %[[ARG0]]
+// CHECK-SAME:    outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:    into %[[ARG0_EMPTY]]
+// CHECK:       %[[RES:.+]] = linalg.generic
+// CHECK-SAME:    indexing_maps = [#[[$MAP]], #[[$MAP]]]
+// CHECK-SAME:    ins(%[[PACKED_ARG0]]
+// CHECK-SAME:    outs(%[[ARG1_EMPTY]]
+
+// -----
+
 #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 
 func.func @unpack_on_output(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x56x56x64xf32> {
@@ -497,7 +531,7 @@ func.func @unpack_on_input(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56
 
 #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 
-func.func @unpack_element_type_change(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56x56x64xf16>) -> tensor<12x56x56x64xf16> {
+func.func @unpack_element_type_change_no_use(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56x56x64xf16>) -> tensor<12x56x56x64xf16> {
   %0 = tensor.empty() : tensor<12x56x56x64xf32>
   %1 = linalg.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %0 : tensor<12x2x56x56x32xf32> -> tensor<12x56x56x64xf32>
   %2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1: tensor<12x56x56x64xf32>) outs(%init : tensor<12x56x56x64xf16>) {
@@ -509,17 +543,14 @@ func.func @unpack_element_type_change(%arg0: tensor<12x2x56x56x32xf32>, %init: t
 }
 
 // CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
-// CHECK-LABEL: func.func @unpack_element_type_change
+// CHECK-LABEL: func.func @unpack_element_type_change_no_use
 // CHECK-SAME:  %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK-SAME:  %[[ARG1:[a-zA-Z0-9]+]]
-// CHECK:       %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
-// CHECK:       %[[ARG1_PACK:.+]] = linalg.pack %[[ARG1]]
-// CHECK-SAME:    outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME:    into %[[ARG1_PACK_EMPTY]]
+// CHECK:       %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
 // CHECK:       %[[RES:.+]] = linalg.generic
 // CHECK-SAME:    indexing_maps = [#[[$MAP]], #[[$MAP]]]
 // CHECK-SAME:    ins(%[[ARG0]]
-// CHECK-SAME:    outs(%[[ARG1_PACK]]
+// CHECK-SAME:    outs(%[[EMPTY]]
 // CHECK:       %[[UNPACK:.+]] = linalg.unpack %[[RES]]
 // CHECK-SAME:    outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME:    into %[[ARG1]]
@@ -1402,13 +1433,10 @@ func.func @push_unpack_in_padded_domain_foldable(%arg0: tensor<8x8x4x8xf32>, %de
 // CHECK-SAME:  %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK-SAME:  %[[ARG1:[a-zA-Z0-9]+]]
 // CHECK-SAME:  %[[ARG2:[a-zA-Z0-9]+]]
-// CHECK:       %[[ARG2_PACK_EMPTY:.+]] = tensor.empty
-// CHECK:       %[[ARG2_PACK:.+]] = linalg.pack %[[ARG2]]
-// CHECK-SAME:    inner_dims_pos = [0, 1] inner_tiles = [4, 8]
-// CHECK-SAME:    into %[[ARG2_PACK_EMPTY]]
+// CHECK:       %[[EMPTY:.+]] = tensor.empty
 // CHECK:       %[[GENERIC:.+]] = linalg.generic
 // CHECK-SAME:    ins(%[[ARG0]] : tensor<8x8x4x8xf32>)
-// CHECK-SAME:    outs(%[[ARG2_PACK]] : tensor<?x8x4x8xbf16>)
+// CHECK-SAME:    outs(%[[EMPTY]] : tensor<?x8x4x8xbf16>)
 // CHECK:       %[[UNPACK:.+]] = linalg.unpack %[[GENERIC]]
 // CHECK-SAME:    into %[[ARG2]]
 // CHECK:       return %[[UNPACK]] : tensor<?x64xbf16>
