Commit 2932ae2

[mlir][vector] Determine vector sizes from the result shape in the case of tensor pack
When vector sizes are not passed as inputs to the vector transform operation, they are queried from the static result shape in the case of the tensor.pack op.
Parent: e47fd09
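
To make the shape arithmetic concrete, here is a rough standalone sketch in plain C++ (not the MLIR APIs; `deriveReadShape` is an illustrative helper name, and outer_dims_perm is ignored for brevity): the vector sizes default to the leading source-rank dimensions of the static destination shape, and each tiled dimension is then scaled by its inner tile size to obtain the shape of the transfer read.

// Rough sketch, assuming no outer_dims_perm: vector sizes default to the
// static result shape and are scaled by the inner tiles to get the read shape.
#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical helper mirroring destShape.take_front(sourceRank) followed by
// shape[innerDimsPos[i]] *= innerTiles[i].
static std::vector<int64_t> deriveReadShape(const std::vector<int64_t> &destShape,
                                            size_t sourceRank,
                                            const std::vector<int64_t> &innerDimsPos,
                                            const std::vector<int64_t> &innerTiles) {
  std::vector<int64_t> shape(destShape.begin(), destShape.begin() + sourceRank);
  for (size_t i = 0; i < innerTiles.size(); ++i)
    shape[innerDimsPos[i]] *= innerTiles[i];
  return shape;
}

int main() {
  // Mirrors the padded-pack test below: dest tensor<32x4x1x16x2xf32>,
  // source rank 3, inner_dims_pos = [2, 1], inner_tiles = [16, 2].
  auto shape = deriveReadShape({32, 4, 1, 16, 2}, 3, {2, 1}, {16, 2});
  for (int64_t d : shape)
    std::printf("%lld ", (long long)d); // prints: 32 8 16
  std::printf("\n");
}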

3 files changed: +109 -10 lines

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 37 additions & 10 deletions
@@ -1412,10 +1412,11 @@ static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp,
 
 /// Create a TransferReadOp from `source` with static shape `readShape`. If the
 /// vector type for the read is not the same as the type of `source`, then a
-/// mask is created on the read.
+/// mask is created on the read. If the `doMasking` parameter is set to false,
+/// we update the `inBounds` attribute instead of masking.
 static Value createReadOrMaskedRead(OpBuilder &builder, Location loc,
                                     Value source, ArrayRef<int64_t> readShape,
-                                    Value padValue) {
+                                    Value padValue, bool doMasking = true) {
   assert(llvm::none_of(readShape,
                        [](int64_t s) { return s == ShapedType::kDynamic; }));
   auto sourceShape = dyn_cast<ShapedType>(source.getType()).getShape();
@@ -1424,14 +1425,21 @@ static Value createReadOrMaskedRead(OpBuilder &builder, Location loc,
   auto vectorType = VectorType::get(readShape, padValue.getType());
   int64_t readRank = readShape.size();
   auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+  SmallVector<bool> inBoundsVal(readRank, true);
+  if (!doMasking) {
+    // Update the inBounds attribute.
+    for (unsigned i = 0; i < readRank; i++)
+      inBoundsVal[i] = sourceShape[i] == readShape[i];
+  }
   auto transferReadOp = builder.create<vector::TransferReadOp>(
       loc,
       /*vectorType=*/vectorType,
       /*source=*/source,
       /*indices=*/SmallVector<Value>(readRank, zero),
       /*padding=*/padValue,
-      /*inBounds=*/SmallVector<bool>(readRank, true));
-  if (llvm::equal(readShape, sourceShape)) {
+      /*inBounds=*/inBoundsVal);
+
+  if (llvm::equal(readShape, sourceShape) || !doMasking) {
     return transferReadOp;
   }
   SmallVector<OpFoldResult> mixedSourceDims =
@@ -1525,6 +1533,17 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
   (void)status; // prevent unused variable warning on non-assert builds.
   assert(succeeded(status) && "failed to reify result shapes");
 
+  ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
+  bool doMasking = true;
+
+  // If the input vector sizes are not provided, the vector sizes are
+  // determined by the result tensor shape. In that case, we update the
+  // inBounds attribute instead of masking.
+  if (inputVectorSizes.empty()) {
+    inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
+    doMasking = false;
+  }
+
   // Create masked TransferReadOp.
   SmallVector<int64_t> inputShape(inputVectorSizes);
   auto innerTiles = packOp.getStaticInnerTiles();
@@ -1536,7 +1555,7 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
   for (auto [idx, size] : enumerate(innerTiles))
     inputShape[innerDimsPos[idx]] *= size;
   auto maskedRead = createReadOrMaskedRead(rewriter, loc, packOp.getSource(),
-                                           inputShape, padValue);
+                                           inputShape, padValue, doMasking);
 
   // Create ShapeCastOp.
   SmallVector<int64_t> destShape(inputVectorSizes);
@@ -1763,7 +1782,7 @@ vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
 /// Returns success if `inputVectorSizes` is a valid masking configuraion for
 /// given `shape`, i.e., it meets:
 /// 1. The numbers of elements in both array are equal.
-/// 2. `inputVectorSizes` does nos have dynamic dimensions.
+/// 2. `inputVectorSizes` does not have dynamic dimensions.
 /// 3. All the values in `inputVectorSizes` are greater than or equal to
 ///    static sizes in `shape`.
 static LogicalResult
@@ -1881,18 +1900,26 @@ static LogicalResult vectorizeLinalgOpPrecondition(
   return success();
 }
 
-/// TODO: Use a matcher to check for a constant padding value.
 static LogicalResult
 vectorizePackOpPrecondition(tensor::PackOp packOp,
                             ArrayRef<int64_t> inputVectorSizes) {
   auto padValue = packOp.getPaddingValue();
-  if (padValue && !padValue.getDefiningOp<arith::ConstantOp>()) {
+  Attribute cstAttr;
+  if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
     LDBG("pad value is not constant: " << packOp << "\n");
     return failure();
   }
-
   ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
-  if (failed(isValidMaskedInputVector(
+  ArrayRef<int64_t> inputTensorShape = packOp.getSourceType().getShape();
+  bool satisfyEmptyCond = true;
+  if (inputVectorSizes.empty()) {
+    if (ShapedType::isDynamicShape(resultTensorShape) ||
+        ShapedType::isDynamicShape(inputTensorShape))
+      satisfyEmptyCond = false;
+  }
+
+  if (!satisfyEmptyCond &&
+      failed(isValidMaskedInputVector(
           resultTensorShape.take_front(packOp.getSourceRank()),
           inputVectorSizes)))
     return failure();
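
When masking is disabled on this read, the patch instead marks each dimension in-bounds only where the read shape matches the source shape, letting the padding value cover the out-of-bounds tail. A minimal sketch of that per-dimension comparison, again in plain C++ rather than the MLIR types (`computeInBounds` is an illustrative name, not from the patch):

#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

// Illustrative helper: a dimension is in-bounds iff the read shape equals
// the source shape in that dimension; otherwise padding fills the remainder.
static std::vector<bool> computeInBounds(const std::vector<int64_t> &sourceShape,
                                         const std::vector<int64_t> &readShape) {
  assert(sourceShape.size() == readShape.size());
  std::vector<bool> inBounds(readShape.size());
  for (size_t i = 0; i < readShape.size(); ++i)
    inBounds[i] = (sourceShape[i] == readShape[i]);
  return inBounds;
}

int main() {
  // Mirrors the padded-pack test: source tensor<32x7x15xf32> read as
  // vector<32x8x16xf32> yields in_bounds = [true, false, false].
  for (bool b : computeInBounds({32, 7, 15}, {32, 8, 16}))
    std::printf("%s ", b ? "true" : "false");
  std::printf("\n");
}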

mlir/test/Dialect/Linalg/vectorization-unsupported.mlir

Lines changed: 17 additions & 0 deletions
@@ -109,3 +109,20 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+func.func @test_pack_no_vectorize_dynamic_shape(%arg0: tensor<?xf32>, %arg1: tensor<4x16xf32>) -> tensor<4x16xf32> {
+  %pad = arith.constant 0.000000e+00 : f32
+  // expected-error @+1 {{Attempted to vectorize, but failed}}
+  %pack = tensor.pack %arg0 padding_value(%pad : f32) inner_dims_pos = [0] inner_tiles = [16] into %arg1 : tensor<?xf32> -> tensor<4x16xf32>
+  return %pack : tensor<4x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 : !transform.any_op
+    transform.yield
+  }
+}

mlir/test/Dialect/Linalg/vectorization.mlir

Lines changed: 55 additions & 0 deletions
@@ -812,3 +812,58 @@ func.func @test_vectorize_unpack_no_masks(%source: tensor<8x8x32x16xf32>, %dest:
     transform.yield
   }
 }
+
+// -----
+
+// CHECK-LABEL: test_vectorize_pack_no_vector_sizes
+func.func @test_vectorize_pack_no_vector_sizes(%arg0: tensor<64x4xf32>, %arg1: tensor<2x4x16x2xf32>) -> tensor<2x4x16x2xf32> {
+  %pack = tensor.pack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [16, 2] into %arg1 : tensor<64x4xf32> -> tensor<2x4x16x2xf32>
+  return %pack : tensor<2x4x16x2xf32>
+}
+// CHECK-DAG: %[[cst:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+// CHECK: %[[read:.*]] = vector.transfer_read %{{.*}}[%[[c0]], %[[c0]]], %[[cst]]
+// CHECK-SAME: {in_bounds = [true, true]} : tensor<64x4xf32>, vector<64x4xf32>
+// CHECK: %[[shape_cast:.*]] = vector.shape_cast %[[read]] : vector<64x4xf32> to vector<4x16x2x2xf32>
+// CHECK: %[[transpose:.*]] = vector.transpose %[[shape_cast]], [2, 0, 1, 3] : vector<4x16x2x2xf32> to vector<2x4x16x2xf32>
+// CHECK-DAG: %[[c0_1:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<2x4x16x2xf32>
+// CHECK: %[[write:.*]] = vector.transfer_write %[[transpose]], %[[empty]][%[[c0_1]], %[[c0_1]], %[[c0_1]], %[[c0_1]]]
+// CHECK-SAME: {in_bounds = [true, true, true, true]} : vector<2x4x16x2xf32>, tensor<2x4x16x2xf32>
+// CHECK: return %[[write]] : tensor<2x4x16x2xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: test_vectorize_padded_pack_no_vector_sizes
+func.func @test_vectorize_padded_pack_no_vector_sizes(%arg0: tensor<32x7x15xf32>, %arg1: tensor<32x4x1x16x2xf32>) -> tensor<32x4x1x16x2xf32> {
+  %pad = arith.constant 0.000000e+00 : f32
+  %pack = tensor.pack %arg0 padding_value(%pad : f32) inner_dims_pos = [2, 1] inner_tiles = [16, 2] into %arg1 : tensor<32x7x15xf32> -> tensor<32x4x1x16x2xf32>
+  return %pack : tensor<32x4x1x16x2xf32>
+}
+// CHECK-DAG: %[[cst:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+// CHECK: %[[transfer_read:.*]] = vector.transfer_read %{{.*}}[%[[c0]], %[[c0]], %[[c0]]], %[[cst]]
+// CHECK-SAME: {in_bounds = [true, false, false]} : tensor<32x7x15xf32>, vector<32x8x16xf32>
+// CHECK: %[[shape_cast:.*]] = vector.shape_cast %[[transfer_read]] : vector<32x8x16xf32> to vector<32x4x2x1x16xf32>
+// CHECK: %[[transpose:.*]] = vector.transpose %[[shape_cast]], [0, 1, 3, 4, 2] : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
+// CHECK-DAG: %[[c0_1:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<32x4x1x16x2xf32>
+// CHECK: %[[write:.*]] = vector.transfer_write %[[transpose]], %[[empty]][%[[c0_1]], %[[c0_1]], %[[c0_1]], %[[c0_1]], %[[c0_1]]]
+// CHECK-SAME: {in_bounds = [true, true, true, true, true]} : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
+// CHECK: return %[[write]] : tensor<32x4x1x16x2xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 : !transform.any_op
+    transform.yield
+  }
+}
