Skip to content

Commit d21beb5

Browse files
committed
[MLIR][Linalg] Avoid padding attribute in pack when possible
If we deal with statically known tensors and tiles, and a given tile perfectly divides a given dimension, we can omit the padding attribute. As a bonus, we can now run pack and unpack propagation (currently, we bail out during propagation if the padding attribute is present). Reviewed By: nicolasvasilache. Differential Revision: https://reviews.llvm.org/D154607
1 parent 2f60813 commit d21beb5

File tree

4 files changed

+89
-7
lines changed

4 files changed

+89
-7
lines changed

mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1823,7 +1823,9 @@ def Tensor_PackOp : Tensor_RelayoutOp<"pack", [
18231823

18241824
// Returns true if we have enough static information to catch undefined
18251825
// behavior when the tile size does not divide perfectly the dimension of
1826-
// the input tensor.
1826+
// the input tensor. If a given dimension or a tile associated with it is
1827+
// dynamic, the dimension is not considered as we don't have enough static
1828+
// information to understand if the tile perfectly divides that dimension.
18271829
static bool requirePaddingValue(ArrayRef<int64_t> inputShape,
18281830
ArrayRef<int64_t> innerDimsPos,
18291831
ArrayRef<OpFoldResult> innerTiles);

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -563,12 +563,25 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
563563
Value dest = tensor::PackOp::createDestinationTensor(
564564
rewriter, loc, operand, innerPackSizes, innerPos,
565565
/*outerDimsPerm=*/{});
566-
// TODO: value of the padding attribute should be determined by consumers.
567-
auto zeroAttr =
568-
rewriter.getZeroAttr(getElementTypeOrSelf(dest.getType()));
569-
Value zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
570-
packOps.push_back(rewriter.create<tensor::PackOp>(
571-
loc, operand, dest, innerPos, innerPackSizes, zero));
566+
ShapedType operandType = operand.getType().cast<ShapedType>();
567+
bool areConstantTiles =
568+
llvm::all_of(innerPackSizes, [](OpFoldResult tile) {
569+
return getConstantIntValue(tile).has_value();
570+
});
571+
if (areConstantTiles && operandType.hasStaticShape() &&
572+
!tensor::PackOp::requirePaddingValue(operandType.getShape(), innerPos,
573+
innerPackSizes)) {
574+
packOps.push_back(rewriter.create<tensor::PackOp>(
575+
loc, operand, dest, innerPos, innerPackSizes));
576+
} else {
577+
// TODO: value of the padding attribute should be determined by
578+
// consumers.
579+
auto zeroAttr =
580+
rewriter.getZeroAttr(getElementTypeOrSelf(dest.getType()));
581+
Value zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
582+
packOps.push_back(rewriter.create<tensor::PackOp>(
583+
loc, operand, dest, innerPos, innerPackSizes, zero));
584+
}
572585
inputsAndInits.push_back(packOps.back());
573586
}
574587
}

mlir/test/Dialect/Linalg/transform-op-pack.mlir

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -593,3 +593,36 @@ transform.sequence failures(propagate) {
593593
: (!transform.op<"tensor.unpack">, !transform.op<"linalg.generic">)
594594
-> (!transform.op<"linalg.generic">, !transform.op<"tensor.pack">, !transform.op<"tensor.unpack">)
595595
}
596+
597+
// -----
598+
599+
func.func @no_padding_on_packs(%A: tensor<32x32xf32>, %B: tensor<32x32xf32>, %C: tensor<32x32xf32>)
600+
-> tensor<32x32xf32> {
601+
%0 = linalg.matmul ins(%A, %B: tensor<32x32xf32>, tensor<32x32xf32>)
602+
outs(%C: tensor<32x32xf32>)
603+
-> tensor<32x32xf32>
604+
return %0 : tensor<32x32xf32>
605+
}
606+
607+
// CHECK-LABEL: no_padding_on_packs
608+
// CHECK: tensor.pack %{{.+}} inner_dims_pos = [0, 1] inner_tiles = [4, 8]
609+
// CHECK-SAME: into %{{.+}} : tensor<32x32xf32> -> tensor<8x4x4x8xf32>
610+
// CHECK: tensor.pack %{{.+}} outer_dims_perm = [1, 0]
611+
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [8, 8]
612+
// CHECK-SAME: into %{{.+}} : tensor<32x32xf32> -> tensor<4x4x8x8xf32>
613+
// CHECK: tensor.pack %{{.+}} inner_dims_pos = [0, 1] inner_tiles = [4, 8]
614+
// CHECK-SAME: into %{{.+}} : tensor<32x32xf32> -> tensor<8x4x4x8xf32>
615+
616+
transform.sequence failures(propagate) {
617+
^bb0(%arg1: !transform.any_op):
618+
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
619+
%1 = transform.structured.pack %0 packed_sizes = [4, 8, 8]
620+
: (!transform.any_op) -> (!transform.op<"linalg.generic">)
621+
%pack = transform.get_producer_of_operand %1[1]
622+
: (!transform.op<"linalg.generic">) -> (!transform.op<"tensor.pack">)
623+
%2, %pack_2, %empty_unpack_2 =
624+
transform.structured.pack_transpose %pack with_compute_op(%1)
625+
outer_perm = [1, 0] inner_perm = [1, 0]
626+
: (!transform.op<"tensor.pack">, !transform.op<"linalg.generic">)
627+
-> (!transform.op<"linalg.generic">, !transform.op<"tensor.pack">, !transform.any_op)
628+
}

mlir/test/Dialect/Linalg/transform-pack-greedily.mlir

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,3 +348,37 @@ transform.sequence failures(propagate) {
348348
matmul_packed_sizes = [8, 16, 32] matmul_inner_dims_order = [1, 2, 0]
349349
: (!transform.op<"linalg.matvec">) -> !transform.any_op
350350
}
351+
352+
// -----
353+
354+
func.func @no_padding_on_packs(%A: tensor<32x32xf32>, %B: tensor<32x32xf32>, %C: tensor<32x32xf32>)
355+
-> tensor<32x32xf32> {
356+
%0 = linalg.matmul ins(%A, %B: tensor<32x32xf32>, tensor<32x32xf32>)
357+
outs(%C: tensor<32x32xf32>)
358+
-> tensor<32x32xf32>
359+
return %0 : tensor<32x32xf32>
360+
}
361+
362+
// CHECK-LABEL: no_padding_on_packs
363+
// CHECK: tensor.pack %{{.+}} inner_dims_pos = [0, 1] inner_tiles = [8, 4]
364+
// CHECK-SAME: into %{{.+}} : tensor<32x32xf32> -> tensor<4x8x8x4xf32>
365+
// CHECK: tensor.pack %{{.+}} outer_dims_perm = [1, 0]
366+
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [4, 16] into %{{.+}} : tensor<32x32xf32> -> tensor<2x8x4x16xf32>
367+
// CHECK: tensor.pack %{{.+}} inner_dims_pos = [0, 1] inner_tiles = [8, 16]
368+
// CHECK-SAME: into %{{.+}} : tensor<32x32xf32> -> tensor<4x2x8x16xf32>
369+
370+
transform.sequence failures(propagate) {
371+
^bb0(%arg1: !transform.any_op):
372+
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
373+
: (!transform.any_op) -> !transform.op<"linalg.matmul">
374+
%1 = transform.structured.pack_greedily %0
375+
matmul_packed_sizes = [8, 16, 4] matmul_inner_dims_order = [0, 1, 2]
376+
: (!transform.op<"linalg.matmul">) -> !transform.op<"linalg.generic">
377+
%pack = transform.get_producer_of_operand %1[1]
378+
: (!transform.op<"linalg.generic">) -> (!transform.op<"tensor.pack">)
379+
%2, %pack_2, %empty_unpack_2 =
380+
transform.structured.pack_transpose %pack with_compute_op(%1)
381+
outer_perm = [1, 0] inner_perm = [1, 0]
382+
: (!transform.op<"tensor.pack">, !transform.op<"linalg.generic">)
383+
-> (!transform.op<"linalg.generic">, !transform.op<"tensor.pack">, !transform.any_op)
384+
}

0 commit comments

Comments
 (0)