
Commit 24d9bdd

[mlir][vector] Determine vector sizes from the result shape in the case of tensor pack
When vector sizes are not passed as inputs to the vectorization transform operation, they are now inferred from the static result shape in the case of the tensor.pack op.
1 parent e47fd09 commit 24d9bdd
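As a usage sketch (adapted from the tests added in this commit; the function name is illustrative), a tensor.pack whose destination type is fully static can now be vectorized through the transform dialect without supplying any vector sizes:

// Destination type is fully static, so no vector sizes need to be supplied.
func.func @pack_no_vector_sizes(%src: tensor<64x4xf32>, %dst: tensor<2x4x16x2xf32>) -> tensor<2x4x16x2xf32> {
  %pack = tensor.pack %src outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [16, 2]
      into %dst : tensor<64x4xf32> -> tensor<2x4x16x2xf32>
  return %pack : tensor<2x4x16x2xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["tensor.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    // No vector sizes are passed; they are inferred from the static destination shape.
    transform.structured.vectorize %0 : !transform.any_op
    transform.yield
  }
}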

File tree

2 files changed: +77 -4 lines changed


mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 16 additions & 4 deletions
@@ -1525,6 +1525,17 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
   (void)status; // prevent unused variable warning on non-assert builds.
   assert(succeeded(status) && "failed to reify result shapes");
 
+  ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
+
+  // If the input vector sizes are not provided, then the vector sizes are
+  // determined by the result tensor shape.
+  if (inputVectorSizes.empty()) {
+    // Make sure that the result tensor shape is static.
+    if (ShapedType::isDynamicShape(resultTensorShape))
+      return failure();
+    inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
+  }
+
   // Create masked TransferReadOp.
   SmallVector<int64_t> inputShape(inputVectorSizes);
   auto innerTiles = packOp.getStaticInnerTiles();
@@ -1763,7 +1774,7 @@ vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
17631774
/// Returns success if `inputVectorSizes` is a valid masking configuraion for
17641775
/// given `shape`, i.e., it meets:
17651776
/// 1. The numbers of elements in both array are equal.
1766-
/// 2. `inputVectorSizes` does nos have dynamic dimensions.
1777+
/// 2. `inputVectorSizes` does not have dynamic dimensions.
17671778
/// 3. All the values in `inputVectorSizes` are greater than or equal to
17681779
/// static sizes in `shape`.
17691780
static LogicalResult
@@ -1881,18 +1892,19 @@ static LogicalResult vectorizeLinalgOpPrecondition(
18811892
return success();
18821893
}
18831894

1884-
/// TODO: Use a matcher to check for a constant padding value.
18851895
static LogicalResult
18861896
vectorizePackOpPrecondition(tensor::PackOp packOp,
18871897
ArrayRef<int64_t> inputVectorSizes) {
18881898
auto padValue = packOp.getPaddingValue();
1889-
if (padValue && !padValue.getDefiningOp<arith::ConstantOp>()) {
1899+
Attribute cstAttr;
1900+
if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
18901901
LDBG("pad value is not constant: " << packOp << "\n");
18911902
return failure();
18921903
}
18931904

18941905
ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
1895-
if (failed(isValidMaskedInputVector(
1906+
if (!inputVectorSizes.empty() &&
1907+
failed(isValidMaskedInputVector(
18961908
resultTensorShape.take_front(packOp.getSourceRank()),
18971909
inputVectorSizes)))
18981910
return failure();
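As a worked example of the inference above (using the first new test below): the destination type is tensor<2x4x16x2xf32> and the source rank is 2, so when no sizes are supplied the code picks resultTensorShape.take_front(2) = [2, 4] as the vector sizes, which is only legal because every dimension of the destination shape is static.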

mlir/test/Dialect/Linalg/vectorization.mlir

Lines changed: 61 additions & 0 deletions
@@ -812,3 +812,64 @@ func.func @test_vectorize_unpack_no_masks(%source: tensor<8x8x32x16xf32>, %dest:
     transform.yield
   }
 }
+
+// -----
+
+// CHECK-LABEL: test_vectorize_pack_no_vector_sizes
+func.func @test_vectorize_pack_no_vector_sizes(%arg0: tensor<64x4xf32>, %arg1: tensor<2x4x16x2xf32>) -> tensor<2x4x16x2xf32> {
+  %pack = tensor.pack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [16, 2] into %arg1 : tensor<64x4xf32> -> tensor<2x4x16x2xf32>
+  return %pack : tensor<2x4x16x2xf32>
+}
+// CHECK-DAG: %[[cst:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+// CHECK: %[[read:.*]] = vector.transfer_read %{{.*}}[%[[c0]], %[[c0]]], %[[cst]]
+// CHECK-SAME: {in_bounds = [true, true]} : tensor<64x4xf32>, vector<64x4xf32>
+// CHECK: %[[shape_cast:.*]] = vector.shape_cast %[[read]] : vector<64x4xf32> to vector<4x16x2x2xf32>
+// CHECK: %[[transpose:.*]] = vector.transpose %[[shape_cast]], [2, 0, 1, 3] : vector<4x16x2x2xf32> to vector<2x4x16x2xf32>
+// CHECK-DAG: %[[c0_1:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<2x4x16x2xf32>
+// CHECK: %[[write:.*]] = vector.transfer_write %[[transpose]], %[[empty]][%[[c0_1]], %[[c0_1]], %[[c0_1]], %[[c0_1]]]
+// CHECK-SAME: {in_bounds = [true, true, true, true]} : vector<2x4x16x2xf32>, tensor<2x4x16x2xf32>
+// CHECK: return %[[write]] : tensor<2x4x16x2xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: test_vectorize_padded_pack_no_vector_sizes
+func.func @test_vectorize_padded_pack_no_vector_sizes(%arg0: tensor<32x7x15xf32>, %arg1: tensor<32x4x1x16x2xf32>) -> tensor<32x4x1x16x2xf32> {
+  %pad = arith.constant 0.000000e+00 : f32
+  %pack = tensor.pack %arg0 padding_value(%pad : f32) inner_dims_pos = [2, 1] inner_tiles = [16, 2] into %arg1 : tensor<32x7x15xf32> -> tensor<32x4x1x16x2xf32>
+  return %pack : tensor<32x4x1x16x2xf32>
+}
+// CHECK-DAG: %[[cst:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[c32:.*]] = arith.constant 32 : index
+// CHECK-DAG: %[[c7:.*]] = arith.constant 7 : index
+// CHECK-DAG: %[[c15:.*]] = arith.constant 15 : index
+// CHECK: %[[mask:.*]] = vector.create_mask %[[c32]], %[[c7]], %[[c15]] : vector<32x8x16xi1>
+// CHECK: %[[masked_read:.*]] = vector.mask %[[mask]] {
+// CHECK-SAME: vector.transfer_read %{{.*}}[%[[c0]], %[[c0]], %[[c0]]], %[[cst]]
+// CHECK-SAME: {in_bounds = [true, true, true]} : tensor<32x7x15xf32>, vector<32x8x16xf32>
+// CHECK-SAME: } : vector<32x8x16xi1> -> vector<32x8x16xf32>
+// CHECK: %[[shape_cast:.*]] = vector.shape_cast %[[masked_read]] : vector<32x8x16xf32> to vector<32x4x2x1x16xf32>
+// CHECK: %[[transpose:.*]] = vector.transpose %[[shape_cast]], [0, 1, 3, 4, 2] : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
+// CHECK-DAG: %[[c0_1:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<32x4x1x16x2xf32>
+// CHECK: %[[write:.*]] = vector.transfer_write %[[transpose]], %[[empty]][%[[c0_1]], %[[c0_1]], %[[c0_1]], %[[c0_1]], %[[c0_1]]]
+// CHECK-SAME: {in_bounds = [true, true, true, true, true]} : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
+// CHECK: return %[[write]] : tensor<32x4x1x16x2xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 : !transform.any_op
+    transform.yield
+  }
+}
