Commit cf0d060
[mlir][vector] Determine vector sizes from the result shape in the case of tensor pack
When vector sizes are not passed to the vectorize transform operation, they are now inferred from the static result shape of the tensor.pack op.
Parent: e47fd09
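In transform-dialect terms, this means a tensor.pack with a fully static destination type can now be vectorized without an explicit vector_sizes list. A minimal sketch, based on the test added at the bottom of this commit:

  %0 = transform.structured.match ops{["tensor.pack"]} in %arg0
      : (!transform.any_op) -> !transform.any_op
  // No vector_sizes clause: the sizes are inferred from the static dest shape.
  transform.structured.vectorize %0 : !transform.any_op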

2 files changed: +50 −4 lines

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (16 additions, 4 deletions)

@@ -1525,6 +1525,17 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
   (void)status; // prevent unused variable warning on non-assert builds.
   assert(succeeded(status) && "failed to reify result shapes");
 
+  ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
+
+  // If the input vector sizes are not provided, then the vector sizes are
+  // determined by the result tensor shape.
+  if (inputVectorSizes.empty()) {
+    // Make sure that the result tensor shape is static.
+    if (ShapedType::isDynamicShape(resultTensorShape))
+      return failure();
+    inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
+  }
+
   // Create masked TransferReadOp.
   SmallVector<int64_t> inputShape(inputVectorSizes);
   auto innerTiles = packOp.getStaticInnerTiles();

@@ -1763,7 +1774,7 @@ vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
 /// Returns success if `inputVectorSizes` is a valid masking configuraion for
 /// given `shape`, i.e., it meets:
 /// 1. The numbers of elements in both array are equal.
-/// 2. `inputVectorSizes` does nos have dynamic dimensions.
+/// 2. `inputVectorSizes` does not have dynamic dimensions.
 /// 3. All the values in `inputVectorSizes` are greater than or equal to
 ///    static sizes in `shape`.
 static LogicalResult

@@ -1881,18 +1892,19 @@ static LogicalResult vectorizeLinalgOpPrecondition(
   return success();
 }
 
-/// TODO: Use a matcher to check for a constant padding value.
 static LogicalResult
 vectorizePackOpPrecondition(tensor::PackOp packOp,
                             ArrayRef<int64_t> inputVectorSizes) {
   auto padValue = packOp.getPaddingValue();
-  if (padValue && !padValue.getDefiningOp<arith::ConstantOp>()) {
+  Attribute cstAttr;
+  if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
     LDBG("pad value is not constant: " << packOp << "\n");
     return failure();
   }
 
   ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
-  if (failed(isValidMaskedInputVector(
+  if (!inputVectorSizes.empty() &&
+      failed(isValidMaskedInputVector(
           resultTensorShape.take_front(packOp.getSourceRank()),
           inputVectorSizes)))
     return failure();
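To make the inference concrete, here is a sketch of how the new branch plays out on the test case added below; the scaling by inner tiles happens in the pre-existing inputShape computation that follows the inserted block:

  %pack = tensor.pack %arg0 padding_value(%pad : f32)
      inner_dims_pos = [2, 1] inner_tiles = [16, 2]
      into %arg1 : tensor<32x7x15xf32> -> tensor<32x4x1x16x2xf32>
  // dest shape = [32, 4, 1, 16, 2]; source rank = 3
  // inferred inputVectorSizes = take_front(3) of the dest shape = [32, 4, 1]
  // masked transfer_read shape = [32, 4*2, 1*16] = [32, 8, 16]

This matches the vector<32x8x16xf32> transfer_read checked for in the new test.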

mlir/test/Dialect/Linalg/vectorization.mlir (34 additions)

@@ -812,3 +812,37 @@ func.func @test_vectorize_unpack_no_masks(%source: tensor<8x8x32x16xf32>, %dest:
     transform.yield
   }
 }
+
+// -----
+
+// CHECK-LABEL: test_vectorize_padded_pack_no_vector_sizes
+func.func @test_vectorize_padded_pack_no_vector_sizes(%arg0: tensor<32x7x15xf32>, %arg1: tensor<32x4x1x16x2xf32>) -> tensor<32x4x1x16x2xf32> {
+  %pad = arith.constant 0.000000e+00 : f32
+  %pack = tensor.pack %arg0 padding_value(%pad : f32) inner_dims_pos = [2, 1] inner_tiles = [16, 2] into %arg1 : tensor<32x7x15xf32> -> tensor<32x4x1x16x2xf32>
+  return %pack : tensor<32x4x1x16x2xf32>
+}
+// CHECK-DAG: %[[cst:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[c32:.*]] = arith.constant 32 : index
+// CHECK-DAG: %[[c7:.*]] = arith.constant 7 : index
+// CHECK-DAG: %[[c15:.*]] = arith.constant 15 : index
+// CHECK: %[[mask:.*]] = vector.create_mask %[[c32]], %[[c7]], %[[c15]] : vector<32x8x16xi1>
+// CHECK: %[[masked_read:.*]] = vector.mask %[[mask]] {
+// CHECK-SAME:   vector.transfer_read %{{.*}}[%[[c0]], %[[c0]], %[[c0]]], %[[cst]]
+// CHECK-SAME:   {in_bounds = [true, true, true]} : tensor<32x7x15xf32>, vector<32x8x16xf32>
+// CHECK-SAME: } : vector<32x8x16xi1> -> vector<32x8x16xf32>
+// CHECK: %[[shape_cast:.*]] = vector.shape_cast %[[masked_read]] : vector<32x8x16xf32> to vector<32x4x2x1x16xf32>
+// CHECK: %[[transpose:.*]] = vector.transpose %[[shape_cast]], [0, 1, 3, 4, 2] : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
+// CHECK-DAG: %[[c0_1:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<32x4x1x16x2xf32>
+// CHECK: %[[write:.*]] = vector.transfer_write %[[transpose]], %[[empty]][%[[c0_1]], %[[c0_1]], %[[c0_1]], %[[c0_1]], %[[c0_1]]]
+// CHECK-SAME:   {in_bounds = [true, true, true, true, true]} : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
+// CHECK: return %[[write]] : tensor<32x4x1x16x2xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 : !transform.any_op
+    transform.yield
+  }
+}
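For contrast, the pre-existing pack tests in this file drive the same transform with an explicit vector_sizes list. Under the new inference, passing the sizes the commit would derive for this test by hand should be interchangeable with omitting them; a sketch, not part of this commit:

  %0 = transform.structured.match ops{["tensor.pack"]} in %arg0
      : (!transform.any_op) -> !transform.any_op
  // Explicit sizes matching what the commit infers: take_front(3) of the
  // static dest shape [32, 4, 1, 16, 2].
  transform.structured.vectorize %0 vector_sizes [32, 4, 1] : !transform.any_op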
