
Commit ce5381e

[mlir][vector] Determine vector sizes from the result shape in the case of tensor.pack (#88249)

When the vector sizes are not passed as inputs to the vector transform operation, they are queried from the static result shape in the case of the tensor.pack op.
1 parent be50a25 commit ce5381e
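
In practice, transform.structured.vectorize can now be applied to a statically shaped tensor.pack without passing vector sizes. A minimal sketch of the invocation, hand-written to mirror the tests added below:

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["tensor.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    // No vector sizes are passed; they are inferred from the static result shape.
    transform.structured.vectorize %0 : !transform.any_op
    transform.yield
  }
}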

3 files changed (+113 −13 lines)

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 41 additions & 13 deletions
@@ -1412,10 +1412,11 @@ static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp,

 /// Create a TransferReadOp from `source` with static shape `readShape`. If the
 /// vector type for the read is not the same as the type of `source`, then a
-/// mask is created on the read.
+/// mask is created on the read. If the `doMasking` parameter is set to false,
+/// we update the `inBounds` attribute instead of masking.
 static Value createReadOrMaskedRead(OpBuilder &builder, Location loc,
                                     Value source, ArrayRef<int64_t> readShape,
-                                    Value padValue) {
+                                    Value padValue, bool doMasking = true) {
   assert(llvm::none_of(readShape,
                        [](int64_t s) { return s == ShapedType::kDynamic; }));
   auto sourceShape = dyn_cast<ShapedType>(source.getType()).getShape();
@@ -1424,14 +1425,21 @@ static Value createReadOrMaskedRead(OpBuilder &builder, Location loc,
   auto vectorType = VectorType::get(readShape, padValue.getType());
   int64_t readRank = readShape.size();
   auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+  SmallVector<bool> inBoundsVal(readRank, true);
+  if (!doMasking) {
+    // Update the inBounds attribute.
+    for (unsigned i = 0; i < readRank; i++)
+      inBoundsVal[i] = sourceShape[i] == readShape[i];
+  }
   auto transferReadOp = builder.create<vector::TransferReadOp>(
       loc,
       /*vectorType=*/vectorType,
       /*source=*/source,
       /*indices=*/SmallVector<Value>(readRank, zero),
       /*padding=*/padValue,
-      /*inBounds=*/SmallVector<bool>(readRank, true));
-  if (llvm::equal(readShape, sourceShape)) {
+      /*inBounds=*/inBoundsVal);
+
+  if (llvm::equal(readShape, sourceShape) || !doMasking) {
     return transferReadOp;
   }
   SmallVector<OpFoldResult> mixedSourceDims =
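
For intuition, a hand-written sketch (not compiler output; the shapes are a 2-D reduction of the padded test below) of the unmasked form this helper emits when `doMasking` is false. The read stays a plain vector.transfer_read, and `in_bounds` is false exactly for the dims where `readShape` exceeds the source shape, so the padding value fills the overhanging lanes:

func.func @unmasked_read_sketch(%src: tensor<7x15xf32>) -> vector<8x16xf32> {
  %c0 = arith.constant 0 : index
  %pad = arith.constant 0.000000e+00 : f32
  // readShape = [8, 16] vs sourceShape = [7, 15]: neither dim matches,
  // so inBoundsVal = [false, false] and %pad fills the out-of-bounds lanes.
  %r = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [false, false]}
      : tensor<7x15xf32>, vector<8x16xf32>
  return %r : vector<8x16xf32>
}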
@@ -1482,11 +1490,10 @@ static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
   return write;
 }

-/// Vectorize tensor::PackOp with (1) static innerTiles and (2) constant
-/// padding value into:
+/// Vectorize tensor::PackOp with (1) static innerTiles, (2) constant
+/// padding value, and (3) input vector sizes into:
 /// masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds
 /// As in the following example:
-///
 /// %pack = tensor.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2]
 ///     into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
 ///
@@ -1505,6 +1512,10 @@ static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
 ///      %empty[%c0_0, %c0_0, %c0_0, %c0_0, %c0_0]
 ///      {in_bounds = [true, true, true, true, true]}
 ///      : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
+///
+/// If the (3) input vector sizes are not provided, the vector sizes are
+/// determined by the result tensor shape. In that case, we update the
+/// inBounds attribute instead of masking.
 static LogicalResult
 vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
                         ArrayRef<int64_t> inputVectorSizes,
@@ -1525,6 +1536,16 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
   (void)status; // prevent unused variable warning on non-assert builds.
   assert(succeeded(status) && "failed to reify result shapes");

+  // If the input vector sizes are not provided, then the vector sizes are
+  // determined by the result tensor shape. In that case, we also update the
+  // inBounds attribute instead of masking.
+  bool doMasking = true;
+  if (inputVectorSizes.empty()) {
+    ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
+    inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
+    doMasking = false;
+  }
+
   // Create masked TransferReadOp.
   SmallVector<int64_t> inputShape(inputVectorSizes);
   auto innerTiles = packOp.getStaticInnerTiles();
@@ -1536,7 +1557,7 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
   for (auto [idx, size] : enumerate(innerTiles))
     inputShape[innerDimsPos[idx]] *= size;
   auto maskedRead = createReadOrMaskedRead(rewriter, loc, packOp.getSource(),
-                                           inputShape, padValue);
+                                           inputShape, padValue, doMasking);

   // Create ShapeCastOp.
   SmallVector<int64_t> destShape(inputVectorSizes);
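
To make the inference concrete, a worked example using the numbers from the padded test below: the pack writes into tensor<32x4x1x16x2xf32> from a rank-3 source, so take_front(3) of the result shape yields inputVectorSizes = [32, 4, 1]. With inner_dims_pos = [2, 1] and inner_tiles = [16, 2], the loop scales inputShape[2] by 16 and inputShape[1] by 2, giving a read shape of [32, 4*2, 1*16] = [32, 8, 16], i.e. the vector<32x8x16xf32> transfer_read seen in that test.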
@@ -1763,7 +1784,7 @@ vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
 /// Returns success if `inputVectorSizes` is a valid masking configuration for
 /// the given `shape`, i.e., it meets:
 /// 1. The numbers of elements in both arrays are equal.
-/// 2. `inputVectorSizes` does nos have dynamic dimensions.
+/// 2. `inputVectorSizes` does not have dynamic dimensions.
 /// 3. All the values in `inputVectorSizes` are greater than or equal to
 ///    static sizes in `shape`.
 static LogicalResult
@@ -1881,18 +1902,25 @@ static LogicalResult vectorizeLinalgOpPrecondition(
   return success();
 }

-/// TODO: Use a matcher to check for a constant padding value.
 static LogicalResult
 vectorizePackOpPrecondition(tensor::PackOp packOp,
                             ArrayRef<int64_t> inputVectorSizes) {
   auto padValue = packOp.getPaddingValue();
-  if (padValue && !padValue.getDefiningOp<arith::ConstantOp>()) {
+  Attribute cstAttr;
+  if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
     LDBG("pad value is not constant: " << packOp << "\n");
     return failure();
   }
-
   ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
-  if (failed(isValidMaskedInputVector(
+  bool satisfyEmptyCond = true;
+  if (inputVectorSizes.empty()) {
+    if (!packOp.getDestType().hasStaticShape() ||
+        !packOp.getSourceType().hasStaticShape())
+      satisfyEmptyCond = false;
+  }
+
+  if (!satisfyEmptyCond &&
+      failed(isValidMaskedInputVector(
           resultTensorShape.take_front(packOp.getSourceRank()),
           inputVectorSizes)))
     return failure();

mlir/test/Dialect/Linalg/vectorization-unsupported.mlir

Lines changed: 17 additions & 0 deletions
@@ -109,3 +109,20 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+func.func @test_pack_no_vectorize_dynamic_shape(%arg0: tensor<?xf32>, %arg1: tensor<4x16xf32>) -> tensor<4x16xf32> {
+  %pad = arith.constant 0.000000e+00 : f32
+  // expected-error @+1 {{Attempted to vectorize, but failed}}
+  %pack = tensor.pack %arg0 padding_value(%pad : f32) inner_dims_pos = [0] inner_tiles = [16] into %arg1 : tensor<?xf32> -> tensor<4x16xf32>
+  return %pack : tensor<4x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 : !transform.any_op
+    transform.yield
+  }
+}
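
For contrast, a hand-written sketch (not part of this commit) of how the same dynamic-shape pack could still be vectorized: pass explicit vector sizes, which selects the masked path instead of the inBounds one. The vector_sizes clause and the size [4] are assumptions for illustration; the sizes are validated against the leading dims of the result shape (here [4]):

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["tensor.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    // One size per source dim (the source is rank 1 here); masking handles
    // the dynamic tensor<?xf32> source.
    transform.structured.vectorize %0 vector_sizes [4] : !transform.any_op
    transform.yield
  }
}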

mlir/test/Dialect/Linalg/vectorization.mlir

Lines changed: 55 additions & 0 deletions
@@ -930,3 +930,58 @@ func.func @test_vectorize_unpack_no_masks(%source: tensor<8x8x32x16xf32>, %dest:
     transform.yield
   }
 }
+
+// -----
+
+// CHECK-LABEL: test_vectorize_pack_no_vector_sizes
+func.func @test_vectorize_pack_no_vector_sizes(%arg0: tensor<64x4xf32>, %arg1: tensor<2x4x16x2xf32>) -> tensor<2x4x16x2xf32> {
+  %pack = tensor.pack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [16, 2] into %arg1 : tensor<64x4xf32> -> tensor<2x4x16x2xf32>
+  return %pack : tensor<2x4x16x2xf32>
+}
+// CHECK-DAG: %[[cst:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+// CHECK: %[[read:.*]] = vector.transfer_read %{{.*}}[%[[c0]], %[[c0]]], %[[cst]]
+// CHECK-SAME: {in_bounds = [true, true]} : tensor<64x4xf32>, vector<64x4xf32>
+// CHECK: %[[shape_cast:.*]] = vector.shape_cast %[[read]] : vector<64x4xf32> to vector<4x16x2x2xf32>
+// CHECK: %[[transpose:.*]] = vector.transpose %[[shape_cast]], [2, 0, 1, 3] : vector<4x16x2x2xf32> to vector<2x4x16x2xf32>
+// CHECK-DAG: %[[c0_1:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<2x4x16x2xf32>
+// CHECK: %[[write:.*]] = vector.transfer_write %[[transpose]], %[[empty]][%[[c0_1]], %[[c0_1]], %[[c0_1]], %[[c0_1]]]
+// CHECK-SAME: {in_bounds = [true, true, true, true]} : vector<2x4x16x2xf32>, tensor<2x4x16x2xf32>
+// CHECK: return %[[write]] : tensor<2x4x16x2xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: test_vectorize_padded_pack_no_vector_sizes
+func.func @test_vectorize_padded_pack_no_vector_sizes(%arg0: tensor<32x7x15xf32>, %arg1: tensor<32x4x1x16x2xf32>) -> tensor<32x4x1x16x2xf32> {
+  %pad = arith.constant 0.000000e+00 : f32
+  %pack = tensor.pack %arg0 padding_value(%pad : f32) inner_dims_pos = [2, 1] inner_tiles = [16, 2] into %arg1 : tensor<32x7x15xf32> -> tensor<32x4x1x16x2xf32>
+  return %pack : tensor<32x4x1x16x2xf32>
+}
+// CHECK-DAG: %[[cst:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+// CHECK: %[[transfer_read:.*]] = vector.transfer_read %{{.*}}[%[[c0]], %[[c0]], %[[c0]]], %[[cst]]
+// CHECK-SAME: {in_bounds = [true, false, false]} : tensor<32x7x15xf32>, vector<32x8x16xf32>
+// CHECK: %[[shape_cast:.*]] = vector.shape_cast %[[transfer_read]] : vector<32x8x16xf32> to vector<32x4x2x1x16xf32>
+// CHECK: %[[transpose:.*]] = vector.transpose %[[shape_cast]], [0, 1, 3, 4, 2] : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
+// CHECK-DAG: %[[c0_1:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<32x4x1x16x2xf32>
+// CHECK: %[[write:.*]] = vector.transfer_write %[[transpose]], %[[empty]][%[[c0_1]], %[[c0_1]], %[[c0_1]], %[[c0_1]], %[[c0_1]]]
+// CHECK-SAME: {in_bounds = [true, true, true, true, true]} : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
+// CHECK: return %[[write]] : tensor<32x4x1x16x2xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 : !transform.any_op
+    transform.yield
+  }
+}
