Commit 05c23ec

[mlir][vector] Determine vector sizes from the result shape in the case of tensor.pack
When vector sizes are not passed as inputs to the vectorization transform operation, they are inferred from the static result shape of the tensor.pack op.
1 parent e47fd09 commit 05c23ec
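
The behavioral change is easiest to see in isolation: if the caller passes no vector sizes, the leading `sourceRank` dimensions of the (fully static) destination shape become the vector sizes, and masking is turned off. Below is a minimal standalone sketch of that fallback in plain C++. `deriveVectorSizes`, the `kDynamic` sentinel, and the `std::vector` types are illustrative stand-ins, not the MLIR helpers used in the patch:

#include <cstdint>
#include <optional>
#include <vector>

// Stand-in for ShapedType::kDynamic (std::numeric_limits<int64_t>::min()).
constexpr int64_t kDynamic = INT64_MIN;

// Mirrors the new fallback: with no user-provided sizes, the leading
// `sourceRank` dimensions of the (static) destination shape become the
// vector sizes; a dynamic destination shape means vectorization must fail.
std::optional<std::vector<int64_t>>
deriveVectorSizes(const std::vector<int64_t> &userSizes,
                  const std::vector<int64_t> &destShape, size_t sourceRank) {
  if (!userSizes.empty())
    return userSizes; // Explicit sizes win; masking stays enabled.
  for (int64_t d : destShape)
    if (d == kDynamic)
      return std::nullopt;
  return std::vector<int64_t>(destShape.begin(),
                              destShape.begin() + sourceRank);
}

For the first test added below, a destination shape of {2, 4, 16, 2} with source rank 2 yields vector sizes {2, 4}.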

File tree

2 files changed: +88, −9 lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 33 additions & 9 deletions
@@ -1412,10 +1412,12 @@ static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp,
 
 /// Create a TransferReadOp from `source` with static shape `readShape`. If the
 /// vector type for the read is not the same as the type of `source`, then a
-/// mask is created on the read.
+/// mask is created on the read. `doMasking` specifies whether masking is
+/// required or not. If the `doMasking` parameter is set to false, we update
+/// the `inBounds` attribute instead.
 static Value createReadOrMaskedRead(OpBuilder &builder, Location loc,
                                     Value source, ArrayRef<int64_t> readShape,
-                                    Value padValue) {
+                                    Value padValue, bool doMasking = true) {
   assert(llvm::none_of(readShape,
                        [](int64_t s) { return s == ShapedType::kDynamic; }));
   auto sourceShape = dyn_cast<ShapedType>(source.getType()).getShape();
@@ -1424,14 +1426,21 @@ static Value createReadOrMaskedRead(OpBuilder &builder, Location loc,
   auto vectorType = VectorType::get(readShape, padValue.getType());
   int64_t readRank = readShape.size();
   auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+  SmallVector<bool> inBoundsVal(readRank, true);
+  if (!doMasking) {
+    // Update the inBounds attribute.
+    for (unsigned i = 0; i < readRank; i++)
+      inBoundsVal[i] = sourceShape[i] == readShape[i];
+  }
   auto transferReadOp = builder.create<vector::TransferReadOp>(
       loc,
       /*vectorType=*/vectorType,
       /*source=*/source,
       /*indices=*/SmallVector<Value>(readRank, zero),
       /*padding=*/padValue,
-      /*inBounds=*/SmallVector<bool>(readRank, true));
-  if (llvm::equal(readShape, sourceShape)) {
+      /*inBounds=*/inBoundsVal);
+
+  if (llvm::equal(readShape, sourceShape) || !doMasking) {
     return transferReadOp;
   }
   SmallVector<OpFoldResult> mixedSourceDims =
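
The `inBounds` computation above is worth checking by hand. A runnable plain-C++ sketch of the same loop (`computeInBounds` is a hypothetical name, not part of the patch):

#include <cassert>
#include <cstdint>
#include <vector>

// A read dimension is in-bounds iff no padding is needed there, i.e. the
// read extent equals the source extent in that dimension.
std::vector<bool> computeInBounds(const std::vector<int64_t> &sourceShape,
                                  const std::vector<int64_t> &readShape) {
  assert(sourceShape.size() == readShape.size());
  std::vector<bool> inBounds(readShape.size(), true);
  for (size_t i = 0; i < readShape.size(); ++i)
    inBounds[i] = sourceShape[i] == readShape[i];
  return inBounds;
}

int main() {
  // Matches the padded test below: tensor<32x7x15xf32> read as
  // vector<32x8x16xf32> gives in_bounds = [true, false, false].
  auto ib = computeInBounds({32, 7, 15}, {32, 8, 16});
  assert(ib[0] && !ib[1] && !ib[2]);
}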
@@ -1525,6 +1534,20 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
   (void)status; // prevent unused variable warning on non-assert builds.
   assert(succeeded(status) && "failed to reify result shapes");
 
+  ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
+  bool doMasking = true;
+
+  // If the input vector sizes are not provided, the vector sizes are
+  // determined from the static result tensor shape. In that case we update
+  // the inBounds attribute instead of masking.
+  if (inputVectorSizes.empty()) {
+    // Make sure that the result tensor shape is static.
+    if (ShapedType::isDynamicShape(resultTensorShape))
+      return failure();
+    inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
+    doMasking = false;
+  }
+
   // Create masked TransferReadOp.
   SmallVector<int64_t> inputShape(inputVectorSizes);
   auto innerTiles = packOp.getStaticInnerTiles();
@@ -1536,7 +1559,7 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
   for (auto [idx, size] : enumerate(innerTiles))
     inputShape[innerDimsPos[idx]] *= size;
   auto maskedRead = createReadOrMaskedRead(rewriter, loc, packOp.getSource(),
-                                           inputShape, padValue);
+                                           inputShape, padValue, doMasking);
 
   // Create ShapeCastOp.
   SmallVector<int64_t> destShape(inputVectorSizes);
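
The read shape handed to `createReadOrMaskedRead` is the derived vector sizes with each tiled dimension scaled by its inner tile (the `outer_dims_perm` handling happens elsewhere in this function and is the identity in the padded test below). A runnable sketch of just that scaling step, with `computeReadShape` as a hypothetical stand-in:

#include <cstdint>
#include <vector>

// Read shape = vector sizes with each tiled dim scaled by its inner tile,
// mirroring the loop over enumerate(innerTiles) above.
std::vector<int64_t> computeReadShape(std::vector<int64_t> inputShape,
                                      const std::vector<int64_t> &innerDimsPos,
                                      const std::vector<int64_t> &innerTiles) {
  for (size_t idx = 0; idx < innerTiles.size(); ++idx)
    inputShape[innerDimsPos[idx]] *= innerTiles[idx];
  return inputShape;
}

int main() {
  // Padded test below: dest tensor<32x4x1x16x2>, source rank 3, so the
  // derived sizes are {32, 4, 1}; inner_dims_pos = [2, 1] and
  // inner_tiles = [16, 2] scale them to {32, 4*2, 1*16} = {32, 8, 16},
  // i.e. the vector<32x8x16xf32> transfer_read in the CHECK lines.
  auto shape = computeReadShape({32, 4, 1}, {2, 1}, {16, 2});
  (void)shape;
}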
@@ -1763,7 +1786,7 @@ vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
 /// Returns success if `inputVectorSizes` is a valid masking configuration for
 /// given `shape`, i.e., it meets:
 /// 1. The numbers of elements in both arrays are equal.
-/// 2. `inputVectorSizes` does nos have dynamic dimensions.
+/// 2. `inputVectorSizes` does not have dynamic dimensions.
 /// 3. All the values in `inputVectorSizes` are greater than or equal to
 ///    static sizes in `shape`.
 static LogicalResult
@@ -1881,18 +1904,19 @@ static LogicalResult vectorizeLinalgOpPrecondition(
   return success();
 }
 
-/// TODO: Use a matcher to check for a constant padding value.
 static LogicalResult
 vectorizePackOpPrecondition(tensor::PackOp packOp,
                             ArrayRef<int64_t> inputVectorSizes) {
   auto padValue = packOp.getPaddingValue();
-  if (padValue && !padValue.getDefiningOp<arith::ConstantOp>()) {
+  Attribute cstAttr;
+  if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
     LDBG("pad value is not constant: " << packOp << "\n");
     return failure();
   }
 
   ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
-  if (failed(isValidMaskedInputVector(
+  if (!inputVectorSizes.empty() &&
+      failed(isValidMaskedInputVector(
           resultTensorShape.take_front(packOp.getSourceRank()),
           inputVectorSizes)))
     return failure();
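
Two things change in this precondition. First, the pad-value check now uses `matchPattern` with `m_Constant`, which accepts any ConstantLike op rather than only `arith.constant`, resolving the deleted TODO. Second, `isValidMaskedInputVector` is skipped entirely when no vector sizes were supplied, since there is nothing user-provided to validate. A sketch of the matcher-based check (compiles only against the MLIR headers; `isConstantPadValue` is an illustrative name, not part of the patch):

#include "mlir/IR/Matchers.h"
#include "mlir/IR/Value.h"

// Dialect-agnostic constant check: m_Constant succeeds for any op with the
// ConstantLike trait and binds its value attribute, whereas
// padValue.getDefiningOp<arith::ConstantOp>() only matched arith.constant.
static bool isConstantPadValue(mlir::Value padValue) {
  mlir::Attribute cstAttr;
  return mlir::matchPattern(padValue, mlir::m_Constant(&cstAttr));
}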

mlir/test/Dialect/Linalg/vectorization.mlir

Lines changed: 55 additions & 0 deletions
@@ -812,3 +812,58 @@ func.func @test_vectorize_unpack_no_masks(%source: tensor<8x8x32x16xf32>, %dest:
     transform.yield
   }
 }
+
+// -----
+
+// CHECK-LABEL: test_vectorize_pack_no_vector_sizes
+func.func @test_vectorize_pack_no_vector_sizes(%arg0: tensor<64x4xf32>, %arg1: tensor<2x4x16x2xf32>) -> tensor<2x4x16x2xf32> {
+  %pack = tensor.pack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [16, 2] into %arg1 : tensor<64x4xf32> -> tensor<2x4x16x2xf32>
+  return %pack : tensor<2x4x16x2xf32>
+}
+// CHECK-DAG: %[[cst:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+// CHECK: %[[read:.*]] = vector.transfer_read %{{.*}}[%[[c0]], %[[c0]]], %[[cst]]
+// CHECK-SAME: {in_bounds = [true, true]} : tensor<64x4xf32>, vector<64x4xf32>
+// CHECK: %[[shape_cast:.*]] = vector.shape_cast %[[read]] : vector<64x4xf32> to vector<4x16x2x2xf32>
+// CHECK: %[[transpose:.*]] = vector.transpose %[[shape_cast]], [2, 0, 1, 3] : vector<4x16x2x2xf32> to vector<2x4x16x2xf32>
+// CHECK-DAG: %[[c0_1:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<2x4x16x2xf32>
+// CHECK: %[[write:.*]] = vector.transfer_write %[[transpose]], %[[empty]][%[[c0_1]], %[[c0_1]], %[[c0_1]], %[[c0_1]]]
+// CHECK-SAME: {in_bounds = [true, true, true, true]} : vector<2x4x16x2xf32>, tensor<2x4x16x2xf32>
+// CHECK: return %[[write]] : tensor<2x4x16x2xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: test_vectorize_padded_pack_no_vector_sizes
+func.func @test_vectorize_padded_pack_no_vector_sizes(%arg0: tensor<32x7x15xf32>, %arg1: tensor<32x4x1x16x2xf32>) -> tensor<32x4x1x16x2xf32> {
+  %pad = arith.constant 0.000000e+00 : f32
+  %pack = tensor.pack %arg0 padding_value(%pad : f32) inner_dims_pos = [2, 1] inner_tiles = [16, 2] into %arg1 : tensor<32x7x15xf32> -> tensor<32x4x1x16x2xf32>
+  return %pack : tensor<32x4x1x16x2xf32>
+}
+// CHECK-DAG: %[[cst:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+// CHECK: %[[transfer_read:.*]] = vector.transfer_read %{{.*}}[%[[c0]], %[[c0]], %[[c0]]], %[[cst]]
+// CHECK-SAME: {in_bounds = [true, false, false]} : tensor<32x7x15xf32>, vector<32x8x16xf32>
+// CHECK: %[[shape_cast:.*]] = vector.shape_cast %[[transfer_read]] : vector<32x8x16xf32> to vector<32x4x2x1x16xf32>
+// CHECK: %[[transpose:.*]] = vector.transpose %[[shape_cast]], [0, 1, 3, 4, 2] : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
+// CHECK-DAG: %[[c0_1:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<32x4x1x16x2xf32>
+// CHECK: %[[write:.*]] = vector.transfer_write %[[transpose]], %[[empty]][%[[c0_1]], %[[c0_1]], %[[c0_1]], %[[c0_1]], %[[c0_1]]]
+// CHECK-SAME: {in_bounds = [true, true, true, true, true]} : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
+// CHECK: return %[[write]] : tensor<32x4x1x16x2xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 : !transform.any_op
+    transform.yield
+  }
+}
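
Note how the two cases differ: in the first, the read shape equals the source shape (tensor<64x4xf32> read as vector<64x4xf32>), so every in_bounds entry is true and no pad element is ever touched; in the padded case the derived read shape vector<32x8x16xf32> exceeds the tensor<32x7x15xf32> source in the last two dimensions, so in_bounds = [true, false, false] and the constant pad value fills the out-of-bounds lanes.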
