Commit 6f033dc

[mlir][vector] Determine vector sizes from the result shape in the case of tensor pack
When vector sizes are not passed as inputs to the vectorize transform operation, they are now inferred from the static result shape of the tensor.pack op.
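As a minimal sketch of the new behavior (reusing the example from the updated doc comment and the transform script from the new tests; %0 is assumed to already match the tensor.pack op):

  %pack = tensor.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2]
      into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>

  // No vector sizes are passed to the transform; they are derived from the
  // static result shape tensor<32x4x1x16x2xf32>.
  transform.structured.vectorize %0 : !transform.any_op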
Parent: e47fd09

3 files changed: 113 additions & 13 deletions

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 41 additions & 13 deletions

@@ -1412,10 +1412,11 @@ static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp,
 
 /// Create a TransferReadOp from `source` with static shape `readShape`. If the
 /// vector type for the read is not the same as the type of `source`, then a
-/// mask is created on the read.
+/// mask is created on the read. If `doMasking` parameter is set to false we
+/// update the `inBounds` attribute instead of masking.
 static Value createReadOrMaskedRead(OpBuilder &builder, Location loc,
                                     Value source, ArrayRef<int64_t> readShape,
-                                    Value padValue) {
+                                    Value padValue, bool doMasking = true) {
   assert(llvm::none_of(readShape,
                        [](int64_t s) { return s == ShapedType::kDynamic; }));
   auto sourceShape = dyn_cast<ShapedType>(source.getType()).getShape();
@@ -1424,14 +1425,21 @@ static Value createReadOrMaskedRead(OpBuilder &builder, Location loc,
   auto vectorType = VectorType::get(readShape, padValue.getType());
   int64_t readRank = readShape.size();
   auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+  SmallVector<bool> inBoundsVal(readRank, true);
+  if (!doMasking) {
+    // Update the inBounds attribute.
+    for (unsigned i = 0; i < readRank; i++)
+      inBoundsVal[i] = sourceShape[i] == readShape[i];
+  }
   auto transferReadOp = builder.create<vector::TransferReadOp>(
       loc,
       /*vectorType=*/vectorType,
       /*source=*/source,
       /*indices=*/SmallVector<Value>(readRank, zero),
       /*padding=*/padValue,
-      /*inBounds=*/SmallVector<bool>(readRank, true));
-  if (llvm::equal(readShape, sourceShape)) {
+      /*inBounds=*/inBoundsVal);
+
+  if (llvm::equal(readShape, sourceShape) || !doMasking) {
     return transferReadOp;
   }
   SmallVector<OpFoldResult> mixedSourceDims =
@@ -1482,11 +1490,10 @@ static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
   return write;
 }
 
-/// Vectorize tensor::PackOp with (1) static innerTiles and (2) constant
-/// padding value into:
+/// Vectorize tensor::PackOp with (1) static innerTiles (2) constant
+/// padding value and (3) input vector sizes into:
 /// masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds
 /// As in the following example:
-///
 /// %pack = tensor.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2]
 ///   into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
 ///
@@ -1505,6 +1512,10 @@ static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
 ///      %empty[%c0_0, %c0_0, %c0_0, %c0_0, %c0_0]
 ///      {in_bounds = [true, true, true, true, true]}
 ///      : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
+///
+/// If the (3) input vector sizes are not provided, the vector sizes are
+/// determined by the result tensor shape. Also, we update the inBounds
+/// attribute instead of masking.
 static LogicalResult
 vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
                         ArrayRef<int64_t> inputVectorSizes,
@@ -1525,6 +1536,16 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
   (void)status; // prevent unused variable warning on non-assert builds.
   assert(succeeded(status) && "failed to reify result shapes");
 
+  // If the input vector sizes are not provided, then the vector sizes are
+  // determined by the result tensor shape. In case the vector sizes aren't
+  // provided, we update the inBounds attribute instead of masking.
+  bool doMasking = true;
+  if (inputVectorSizes.empty()) {
+    ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
+    inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
+    doMasking = false;
+  }
+
   // Create masked TransferReadOp.
   SmallVector<int64_t> inputShape(inputVectorSizes);
   auto innerTiles = packOp.getStaticInnerTiles();
@@ -1536,7 +1557,7 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
   for (auto [idx, size] : enumerate(innerTiles))
     inputShape[innerDimsPos[idx]] *= size;
   auto maskedRead = createReadOrMaskedRead(rewriter, loc, packOp.getSource(),
-                                           inputShape, padValue);
+                                           inputShape, padValue, doMasking);
 
   // Create ShapeCastOp.
   SmallVector<int64_t> destShape(inputVectorSizes);
@@ -1763,7 +1784,7 @@ vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
 /// Returns success if `inputVectorSizes` is a valid masking configuraion for
 /// given `shape`, i.e., it meets:
 ///   1. The numbers of elements in both array are equal.
-///   2. `inputVectorSizes` does nos have dynamic dimensions.
+///   2. `inputVectorSizes` does not have dynamic dimensions.
 ///   3. All the values in `inputVectorSizes` are greater than or equal to
 ///      static sizes in `shape`.
 static LogicalResult
@@ -1881,18 +1902,25 @@ static LogicalResult vectorizeLinalgOpPrecondition(
   return success();
 }
 
-/// TODO: Use a matcher to check for a constant padding value.
 static LogicalResult
 vectorizePackOpPrecondition(tensor::PackOp packOp,
                             ArrayRef<int64_t> inputVectorSizes) {
   auto padValue = packOp.getPaddingValue();
-  if (padValue && !padValue.getDefiningOp<arith::ConstantOp>()) {
+  Attribute cstAttr;
+  if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
     LDBG("pad value is not constant: " << packOp << "\n");
     return failure();
   }
-
   ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
-  if (failed(isValidMaskedInputVector(
+  bool satisfyEmptyCond = true;
+  if (inputVectorSizes.empty()) {
+    if (!packOp.getDestType().hasStaticShape() ||
+        !packOp.getSourceType().hasStaticShape())
+      satisfyEmptyCond = false;
+  }
+
+  if (!satisfyEmptyCond &&
+      failed(isValidMaskedInputVector(
           resultTensorShape.take_front(packOp.getSourceRank()),
           inputVectorSizes)))
     return failure();
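A worked instance of the new doMasking=false path, using the padded test case added below (the numbers follow directly from the code above): for a pack from tensor<32x7x15xf32> into tensor<32x4x1x16x2xf32> with inner_dims_pos = [2, 1] and inner_tiles = [16, 2], the derived vector sizes are the first three result dims [32, 4, 1]; scaling by the inner tiles gives the read shape [32, 8, 16], and inBoundsVal[i] = (sourceShape[i] == readShape[i]) yields:

  // [32==32, 7==8, 15==16] -> in_bounds = [true, false, false]
  %read = vector.transfer_read %src[%c0, %c0, %c0], %cst
      {in_bounds = [true, false, false]} : tensor<32x7x15xf32>, vector<32x8x16xf32>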

mlir/test/Dialect/Linalg/vectorization-unsupported.mlir

Lines changed: 17 additions & 0 deletions

@@ -109,3 +109,20 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+func.func @test_pack_no_vectorize_dynamic_shape(%arg0: tensor<?xf32>, %arg1: tensor<4x16xf32>) -> tensor<4x16xf32> {
+  %pad = arith.constant 0.000000e+00 : f32
+  // expected-error @+1 {{Attempted to vectorize, but failed}}
+  %pack = tensor.pack %arg0 padding_value(%pad : f32) inner_dims_pos = [0] inner_tiles = [16] into %arg1 : tensor<?xf32> -> tensor<4x16xf32>
+  return %pack : tensor<4x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 : !transform.any_op
+    transform.yield
+  }
+}

mlir/test/Dialect/Linalg/vectorization.mlir

Lines changed: 55 additions & 0 deletions

@@ -812,3 +812,58 @@ func.func @test_vectorize_unpack_no_masks(%source: tensor<8x8x32x16xf32>, %dest:
     transform.yield
   }
 }
+
+// -----
+
+// CHECK-LABEL: test_vectorize_pack_no_vector_sizes
+func.func @test_vectorize_pack_no_vector_sizes(%arg0: tensor<64x4xf32>, %arg1: tensor<2x4x16x2xf32>) -> tensor<2x4x16x2xf32> {
+  %pack = tensor.pack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [16, 2] into %arg1 : tensor<64x4xf32> -> tensor<2x4x16x2xf32>
+  return %pack : tensor<2x4x16x2xf32>
+}
+//  CHECK-DAG: %[[cst:.*]] = arith.constant 0.000000e+00 : f32
+//  CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+//      CHECK: %[[read:.*]] = vector.transfer_read %{{.*}}[%[[c0]], %[[c0]]], %[[cst]]
+// CHECK-SAME:   {in_bounds = [true, true]} : tensor<64x4xf32>, vector<64x4xf32>
+//      CHECK: %[[shape_cast:.*]] = vector.shape_cast %[[read]] : vector<64x4xf32> to vector<4x16x2x2xf32>
+//      CHECK: %[[transpose:.*]] = vector.transpose %[[shape_cast]], [2, 0, 1, 3] : vector<4x16x2x2xf32> to vector<2x4x16x2xf32>
+//  CHECK-DAG: %[[c0_1:.*]] = arith.constant 0 : index
+//  CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<2x4x16x2xf32>
+//      CHECK: %[[write:.*]] = vector.transfer_write %[[transpose]], %[[empty]][%[[c0_1]], %[[c0_1]], %[[c0_1]], %[[c0_1]]]
+// CHECK-SAME:   {in_bounds = [true, true, true, true]} : vector<2x4x16x2xf32>, tensor<2x4x16x2xf32>
+//      CHECK: return %[[write]] : tensor<2x4x16x2xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: test_vectorize_padded_pack_no_vector_sizes
+func.func @test_vectorize_padded_pack_no_vector_sizes(%arg0: tensor<32x7x15xf32>, %arg1: tensor<32x4x1x16x2xf32>) -> tensor<32x4x1x16x2xf32> {
+  %pad = arith.constant 0.000000e+00 : f32
+  %pack = tensor.pack %arg0 padding_value(%pad : f32) inner_dims_pos = [2, 1] inner_tiles = [16, 2] into %arg1 : tensor<32x7x15xf32> -> tensor<32x4x1x16x2xf32>
+  return %pack : tensor<32x4x1x16x2xf32>
+}
+//  CHECK-DAG: %[[cst:.*]] = arith.constant 0.000000e+00 : f32
+//  CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+//      CHECK: %[[transfer_read:.*]] = vector.transfer_read %{{.*}}[%[[c0]], %[[c0]], %[[c0]]], %[[cst]]
+// CHECK-SAME:   {in_bounds = [true, false, false]} : tensor<32x7x15xf32>, vector<32x8x16xf32>
+//      CHECK: %[[shape_cast:.*]] = vector.shape_cast %[[transfer_read]] : vector<32x8x16xf32> to vector<32x4x2x1x16xf32>
+//      CHECK: %[[transpose:.*]] = vector.transpose %[[shape_cast]], [0, 1, 3, 4, 2] : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
+//  CHECK-DAG: %[[c0_1:.*]] = arith.constant 0 : index
+//  CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<32x4x1x16x2xf32>
+//      CHECK: %[[write:.*]] = vector.transfer_write %[[transpose]], %[[empty]][%[[c0_1]], %[[c0_1]], %[[c0_1]], %[[c0_1]], %[[c0_1]]]
+// CHECK-SAME:   {in_bounds = [true, true, true, true, true]} : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
+//      CHECK: return %[[write]] : tensor<32x4x1x16x2xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 : !transform.any_op
+    transform.yield
+  }
+}
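For contrast, the pre-existing masked path is still taken whenever vector sizes are passed explicitly; a hedged sketch using the vector_sizes syntax from other tests in this file (the sizes [32, 4, 1] are illustrative):

  transform.structured.vectorize %0 vector_sizes [32, 4, 1] : !transform.any_op

In that case createReadOrMaskedRead masks the read against the source shape rather than updating the in_bounds attribute.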
