
Commit 22eac91

Use result shape pack vector sizes, clean up

1 parent e08bb0c

File tree

2 files changed: +119 -82 lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 104 additions & 71 deletions
```diff
@@ -24,6 +24,7 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
 #include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
```
```diff
@@ -1454,7 +1455,73 @@ static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp) {
   return applyPermutation(destShape, invertPermutationVector(perm));
 }
 
-///
+/// Create a masked TransferReadOp from `source` with shape `readShape`.
+static vector::MaskOp createMaskedTransferRead(OpBuilder &builder, Location loc,
+                                               Value source,
+                                               ArrayRef<int64_t> readShape,
+                                               Value padValue) {
+  auto maskType = VectorType::get(readShape, builder.getI1Type());
+  auto vectorType = VectorType::get(readShape, padValue.getType());
+  SmallVector<OpFoldResult> mixedSourceDims =
+      tensor::getMixedSizes(builder, loc, source);
+  Value mask =
+      builder.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
+  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+  int64_t readRank = readShape.size();
+  auto transferReadOp = builder.create<vector::TransferReadOp>(
+      loc,
+      /*vectorType=*/vectorType,
+      /*source=*/source,
+      /*indices=*/SmallVector<Value>(readRank, zero),
+      /*padding=*/padValue,
+      /*inBounds=*/SmallVector<bool>(readRank, true));
+  return cast<vector::MaskOp>(
+      mlir::vector::maskOperation(builder, transferReadOp, mask));
+}
+
```
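The helper reads a vector with the static shape `readShape` out of `source`, masking the read with the source's runtime sizes so that lanes beyond the real extent take `padValue`. A standalone sketch of that masking semantics for the 2-D case (plain C++, no MLIR dependencies; the scalar model and all names here are illustrative only):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical 2-D model of vector.mask { vector.transfer_read } as built by
// createMaskedTransferRead: lanes (i, j) with i < d0 && j < d1 come from the
// source; every other lane takes the pad value.
std::vector<float> maskedRead2D(const std::vector<float> &src, int64_t d0,
                                int64_t d1, int64_t rows, int64_t cols,
                                float padValue) {
  std::vector<float> result(rows * cols, padValue);
  for (int64_t i = 0; i < rows && i < d0; ++i)
    for (int64_t j = 0; j < cols && j < d1; ++j)
      result[i * cols + j] = src[i * d1 + j]; // in-bounds lane: copy source
  return result;
}

int main() {
  // A 1x3 source read as a 2x4 vector with pad value 42.43.
  std::vector<float> src = {1.0f, 2.0f, 3.0f};
  for (float v : maskedRead2D(src, /*d0=*/1, /*d1=*/3, /*rows=*/2, /*cols=*/4,
                              42.43f))
    std::cout << v << ' '; // 1 2 3 42.43 42.43 42.43 42.43 42.43
  std::cout << '\n';
}
```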
```diff
+/// Given an input, the mixed destSizes, and the vector sizes for vectorization,
+/// create an empty destination tensor and create a TransferWriteOp from the
+/// input to the empty tensor. If the destination shape is not the same as the
+/// inputVectorSizes for the first rank(inputVectorSizes) dims, then create a
+/// mask for the write.
+static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
+                                           Value input,
+                                           SmallVector<OpFoldResult> destSizes,
+                                           ArrayRef<int64_t> inputVectorSizes) {
+  auto inputType = cast<VectorType>(input.getType());
+  Value dest = builder.create<tensor::EmptyOp>(loc, destSizes,
+                                               inputType.getElementType());
+  int64_t rank = cast<ShapedType>(dest.getType()).getRank();
+  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+  Operation *write = builder.create<vector::TransferWriteOp>(
+      loc,
+      /*vector=*/input,
+      /*source=*/dest,
+      /*indices=*/SmallVector<Value>(rank, zero),
+      /*inBounds=*/SmallVector<bool>(rank, true));
+  auto destShape = cast<ShapedType>(dest.getType()).getShape();
+  bool needMaskForWrite =
+      llvm::any_of(llvm::zip(inputVectorSizes, destShape),
+                   [](auto it) { return std::get<0>(it) != std::get<1>(it); });
+  if (needMaskForWrite) {
+    SmallVector<int64_t> writeMaskShape;
+    writeMaskShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
+    writeMaskShape.append(destShape.begin() + inputVectorSizes.size(),
+                          destShape.end());
+    auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
+    Value maskForWrite =
+        builder.create<vector::CreateMaskOp>(loc, writeMaskType, destSizes);
+    write = mlir::vector::maskOperation(builder, write, maskForWrite);
+  }
+  return write;
+}
+
```
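The write-mask shape is the caller's `inputVectorSizes` extended with the trailing destination dims they do not cover (for a pack, the inner-tile dims), and a mask is only materialized when some leading destination dim differs from its vector size. A minimal standalone sketch of that shape composition, using hypothetical sizes:

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Hypothetical pack example: the vector sizes cover the leading (outer)
  // dest dims; the trailing dest dims are the inner tiles.
  std::vector<int64_t> inputVectorSizes = {4, 1};
  std::vector<int64_t> destShape = {3, 1, 16, 2}; // outer dim 3 != size 4

  // needMaskForWrite: any leading dest dim differing from its vector size.
  bool needMaskForWrite =
      !std::equal(inputVectorSizes.begin(), inputVectorSizes.end(),
                  destShape.begin());

  // writeMaskShape = inputVectorSizes ++ destShape[inputVectorSizes.size()..].
  std::vector<int64_t> writeMaskShape(inputVectorSizes);
  writeMaskShape.insert(writeMaskShape.end(),
                        destShape.begin() + inputVectorSizes.size(),
                        destShape.end());

  std::cout << "needMaskForWrite = " << needMaskForWrite << '\n'; // 1
  for (int64_t d : writeMaskShape)
    std::cout << d << ' '; // 4 1 16 2
  std::cout << '\n';
}
```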
```diff
+/// Vectorize tensor::PackOp with (1) static innerTiles and (2) constant
+/// padding value into
+///   transfer_write_in_bounds(
+///     transpose(
+///       shape_cast(
+///         transfer_read_masked(pack_source, pad_value))))
 static LogicalResult
 vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
                         ArrayRef<int64_t> inputVectorSizes,
```
```diff
@@ -1468,48 +1535,41 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
     padValue = rewriter.create<arith::ConstantOp>(
         loc, rewriter.getZeroAttr(packOp.getSourceType().getElementType()));
   }
-  int64_t inputRank = inputVectorSizes.size();
-  int64_t outputRank = packOp.getDestRank();
-  auto maskType = VectorType::get(inputVectorSizes, rewriter.getI1Type());
-  auto vectorType = VectorType::get(inputVectorSizes, padValue.getType());
-
   ReifiedRankedShapedTypeDims reifiedReturnShapes;
   LogicalResult status =
       cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation())
           .reifyResultShapes(rewriter, reifiedReturnShapes);
   (void)status; // prevent unused variable warning on non-assert builds
   assert(succeeded(status) && "failed to reify result shapes");
-  auto emptyOp = rewriter.create<tensor::EmptyOp>(loc, reifiedReturnShapes[0],
-                                                  padValue.getType());
-  SmallVector<OpFoldResult> mixedSourceDims =
-      tensor::getMixedSizes(rewriter, loc, packOp.getSource());
-  Value mask =
-      rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
-  auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
-      loc,
-      /*vectorType=*/vectorType,
-      /*source=*/packOp.getSource(),
-      /*indices=*/SmallVector<Value>(inputRank, zero),
-      /*padding=*/padValue,
-      /*inBounds=*/SmallVector<bool>(inputRank, true));
-  auto maskedOp = cast<vector::MaskOp>(
-      mlir::vector::maskOperation(rewriter, transferReadOp, mask));
-  // ShapeCast
-  auto tiledPackShape = getTiledPackShape(packOp);
-  auto tiledPackType =
-      VectorType::get(tiledPackShape, packOp.getDestType().getElementType());
+
+  // Create masked TransferReadOp
+  SmallVector<int64_t> inputShape(inputVectorSizes);
+  auto innerTiles = packOp.getStaticInnerTiles();
+  auto innerDimsPos = packOp.getInnerDimsPos();
+  auto outerDimsPerm = packOp.getOuterDimsPerm();
+  if (!outerDimsPerm.empty())
+    applyPermutationToVector(inputShape,
+                             invertPermutationVector(outerDimsPerm));
+  for (auto [idx, size] : enumerate(innerTiles))
+    inputShape[innerDimsPos[idx]] *= size;
+  auto maskedOp = createMaskedTransferRead(rewriter, loc, packOp.getSource(),
+                                           inputShape, padValue);
+
+  // Create ShapeCastOp
+  auto tiledPackType = VectorType::get(getTiledPackShape(packOp),
+                                       packOp.getDestType().getElementType());
   auto shapeCastOp = rewriter.create<vector::ShapeCastOp>(
       loc, tiledPackType, maskedOp->getResult(0));
+
+  // Create TransposeOp
   auto tiledShapeToPackedShapePerm = getTiledShapeToPackedShapePerm(packOp);
   auto transposeOp = rewriter.create<vector::TransposeOp>(
-      loc, shapeCastOp->getResult(0), tiledShapeToPackedShapePerm);
-  Operation *write = rewriter.create<vector::TransferWriteOp>(
-      loc,
-      /*vector=*/transposeOp->getResult(0),
-      /*source=*/emptyOp,
-      /*indices=*/SmallVector<Value>(outputRank, zero),
-      /*inBounds=*/SmallVector<bool>(outputRank, true));
+      loc, shapeCastOp.getResult(), tiledShapeToPackedShapePerm);
+
+  // Create TransferWriteOp
+  Operation *write =
+      createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(),
+                               reifiedReturnShapes[0], inputVectorSizes);
   newResults.push_back(write->getResult(0));
   return success();
 }
```
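The read shape is now derived from the result-shape vector sizes: undo `outer_dims_perm`, then scale each dim listed in `inner_dims_pos` by its tile size. A plain-C++ sketch of the same computation, checked against the dynamic pack test further below (`vector_sizes [4, 1]`; inner tiles `[2, 16]` on dims `[0, 1]`, inferred from its CHECK lines):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Mirror of the read-shape computation above, with no MLIR dependencies.
std::vector<int64_t> packReadShape(std::vector<int64_t> vectorSizes,
                                   const std::vector<int64_t> &innerDimsPos,
                                   const std::vector<int64_t> &innerTiles,
                                   const std::vector<int64_t> &outerDimsPerm) {
  // Undo the outer-dims permutation: inverse[perm[i]] = i, then
  // permuted[i] = vectorSizes[inverse[i]].
  if (!outerDimsPerm.empty()) {
    std::vector<int64_t> inverse(outerDimsPerm.size());
    for (size_t i = 0; i < outerDimsPerm.size(); ++i)
      inverse[outerDimsPerm[i]] = i;
    std::vector<int64_t> permuted(vectorSizes.size());
    for (size_t i = 0; i < inverse.size(); ++i)
      permuted[i] = vectorSizes[inverse[i]];
    vectorSizes = permuted;
  }
  // Each tiled dim reads tileSize times as many source elements.
  for (size_t i = 0; i < innerTiles.size(); ++i)
    vectorSizes[innerDimsPos[i]] *= innerTiles[i];
  return vectorSizes;
}

int main() {
  // vector_sizes [4, 1], inner_dims_pos [0, 1], inner_tiles [2, 16],
  // no outer permutation -> read vector<8x16>, as in the test's CHECKs.
  for (int64_t d : packReadShape({4, 1}, {0, 1}, {2, 16}, {}))
    std::cout << d << ' '; // 8 16
  std::cout << '\n';
}
```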
```diff
@@ -1523,9 +1583,6 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
                        SmallVectorImpl<Value> &newResults) {
   auto padValue = padOp.getConstantPaddingValue();
   Location loc = padOp.getLoc();
-  int64_t rank = inputVectorSizes.size();
-  auto maskType = VectorType::get(inputVectorSizes, rewriter.getI1Type());
-  auto vectorType = VectorType::get(inputVectorSizes, padValue.getType());
 
   // transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))
   OpBuilder::InsertionGuard g(rewriter);
@@ -1537,36 +1594,11 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
           .reifyResultShapes(rewriter, reifiedReturnShapes);
   (void)status; // prevent unused variable warning on non-assert builds
   assert(succeeded(status) && "failed to reify result shapes");
-  auto emptyOp = rewriter.create<tensor::EmptyOp>(loc, reifiedReturnShapes[0],
-                                                  padValue.getType());
-  SmallVector<OpFoldResult> mixedSourceDims =
-      tensor::getMixedSizes(rewriter, loc, padOp.getSource());
-  Value mask =
-      rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
-  auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
-      loc,
-      /*vectorType=*/vectorType,
-      /*source=*/padOp.getSource(),
-      /*indices=*/SmallVector<Value>(rank, zero),
-      /*padding=*/padValue,
-      /*inBounds=*/SmallVector<bool>(rank, true));
-  auto maskedOp = cast<vector::MaskOp>(
-      mlir::vector::maskOperation(rewriter, transferReadOp, mask));
-  Operation *write = rewriter.create<vector::TransferWriteOp>(
-      loc,
-      /*vector=*/maskedOp->getResult(0),
-      /*source=*/emptyOp,
-      /*indices=*/SmallVector<Value>(rank, zero),
-      /*inBounds=*/SmallVector<bool>(rank, true));
-  bool needMaskForWrite = llvm::any_of(
-      llvm::zip_equal(inputVectorSizes, padOp.getResultType().getShape()),
-      [](auto it) { return std::get<0>(it) != std::get<1>(it); });
-  if (needMaskForWrite) {
-    Value maskForWrite = rewriter.create<vector::CreateMaskOp>(
-        loc, maskType, reifiedReturnShapes[0]);
-    write = mlir::vector::maskOperation(rewriter, write, maskForWrite);
-  }
+  auto maskedOp = createMaskedTransferRead(rewriter, loc, padOp.getSource(),
+                                           inputVectorSizes, padValue);
+  Operation *write =
+      createWriteOrMaskedWrite(rewriter, loc, maskedOp->getResult(0),
+                               reifiedReturnShapes[0], inputVectorSizes);
   newResults.push_back(write->getResult(0));
   return success();
 }
```
```diff
@@ -1710,18 +1742,19 @@ static LogicalResult
 vectorizePackOpPrecondition(tensor::PackOp packOp,
                             ArrayRef<int64_t> inputVectorSizes) {
   auto padValue = packOp.getPaddingValue();
-  if (padValue && getConstantIntValue(padValue) != std::nullopt) {
+  if (padValue && !getConstantIntValue(padValue).has_value()) {
     LDBG("pad value is not constant: " << packOp << "\n");
     return failure();
   }
 
-  ArrayRef<int64_t> resultTensorShape = packOp.getSourceType().getShape();
-  if (failed(isValidMaskedInputVector(resultTensorShape, inputVectorSizes)))
+  ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
+  if (failed(isValidMaskedInputVector(
+          resultTensorShape.take_front(packOp.getSourceRank()),
+          inputVectorSizes)))
     return failure();
 
   if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
-        std::optional<int64_t> res = getConstantIntValue(v);
-        return !res.has_value();
+        return !getConstantIntValue(v).has_value();
       })) {
     LDBG("inner_tiles must be constant: " << packOp << "\n");
     return failure();
```
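The precondition now validates the user's vector sizes against the leading source-rank dims of the destination shape rather than against the source shape, which is why the tests below switch from source-shaped sizes like `[8, 16]` to result-shaped sizes like `[4, 1]`. A tiny standalone illustration of what `take_front(getSourceRank())` selects, using the dest type of the dynamic pack test below:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Dynamic pack test below: dest type tensor<4x1x16x2xf32>, source rank 2.
  std::vector<int64_t> destShape = {4, 1, 16, 2};
  size_t sourceRank = 2;

  // resultTensorShape.take_front(sourceRank): the outer dims of the result,
  // which the user-supplied vector sizes are now validated against.
  std::vector<int64_t> leading(destShape.begin(),
                               destShape.begin() + sourceRank);
  for (int64_t d : leading)
    std::cout << d << ' '; // 4 1 -- matches the test's vector_sizes [4, 1]
  std::cout << '\n';
}
```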

mlir/test/Dialect/Linalg/vectorization.mlir

Lines changed: 15 additions & 11 deletions
```diff
@@ -426,7 +426,6 @@ func.func @test_masked_vectorize_pad(
 {
   // CHECK-DAG: %[[c42:.*]] = arith.constant 4.243000e+01 : f32
   // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
-  // CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<2x4xf32>
   // CHECK: %[[d0:.*]] = tensor.dim {{.*}} : tensor<?x?xf32>
   // CHECK: %[[d1:.*]] = tensor.dim {{.*}} : tensor<?x?xf32>
   // CHECK: %[[mask:.*]] = vector.create_mask %[[d0]], %[[d1]] : vector<2x4xi1>
@@ -435,7 +434,9 @@ func.func @test_masked_vectorize_pad(
   // CHECK-SAME: vector.transfer_read %{{.*}}[%[[c0_2]], %[[c0_2]]], %[[c42]]
   // CHECK-SAME: {in_bounds = [true, true]} : tensor<?x?xf32>, vector<2x4xf32>
   // CHECK-SAME: } : vector<2x4xi1> -> vector<2x4xf32>
-  // CHECK: vector.transfer_write %[[masked_read]], %[[empty]][%[[c0_2]], %[[c0_2]]]
+  // CHECK-DAG: %[[c0_3:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<2x4xf32>
+  // CHECK: vector.transfer_write %[[masked_read]], %[[empty]][%[[c0_3]], %[[c0_3]]]
   // CHECK-SAME: {in_bounds = [true, true]} : vector<2x4xf32>, tensor<2x4xf32>
   %cst = arith.constant 42.43 : f32
   %c0 = arith.constant 0 : index
@@ -467,7 +468,6 @@ func.func @test_masked_vectorize_dynamic_pad(
   // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
   // CHECK-DAG: %[[res_d0:.+]] = affine.apply #[[MAP]]()
   // CHECK-DAG: %[[res_d1:.+]] = affine.apply #[[MAP]]()
-  // CHECK-DAG: %[[empty:.*]] = tensor.empty(%[[res_d0]], %[[res_d1]]) : tensor<?x?xf32>
   // CHECK: %[[d0:.*]] = tensor.dim {{.*}} : tensor<?x?xf32>
   // CHECK: %[[d1:.*]] = tensor.dim {{.*}} : tensor<?x?xf32>
   // CHECK: %[[mask:.*]] = vector.create_mask %[[d0]], %[[d1]] : vector<2x4xi1>
@@ -476,9 +476,11 @@ func.func @test_masked_vectorize_dynamic_pad(
   // CHECK-SAME: vector.transfer_read %{{.*}}[%[[c0_2]], %[[c0_2]]], %[[c42]]
   // CHECK-SAME: {in_bounds = [true, true]} : tensor<?x?xf32>, vector<2x4xf32>
   // CHECK-SAME: } : vector<2x4xi1> -> vector<2x4xf32>
+  // CHECK-DAG: %[[empty:.*]] = tensor.empty(%[[res_d0]], %[[res_d1]]) : tensor<?x?xf32>
+  // CHECK-DAG: %[[c0_3:.*]] = arith.constant 0 : index
   // CHECK: %[[mask_2:.*]] = vector.create_mask %[[res_d0]], %[[res_d1]] : vector<2x4xi1>
   // CHECK: %[[masked_write:.*]] = vector.mask %[[mask_2]] {
-  // CHECK-SAME: vector.transfer_write %[[masked_read]], %[[empty]][%[[c0_2]], %[[c0_2]]]
+  // CHECK-SAME: vector.transfer_write %[[masked_read]], %[[empty]][%[[c0_3]], %[[c0_3]]]
   // CHECK-SAME: {in_bounds = [true, true]} : vector<2x4xf32>, tensor<?x?xf32>
   // CHECK: return %[[masked_write]] : tensor<?x?xf32>
   %cst = arith.constant 42.43 : f32
@@ -508,7 +510,7 @@ func.func @test_vectorize_dynamic_pack(%arg0: tensor<?x?xf32>, %arg1: tensor<4x1
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["tensor.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
-    transform.structured.vectorize %0 vector_sizes [8, 16] : !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [4, 1] : !transform.any_op
     transform.yield
   }
 }
@@ -517,15 +519,16 @@ module attributes {transform.with_named_sequence} {
 // CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
 // CHECK-DAG: %[[d0:.*]] = tensor.dim {{.*}} %[[c0]] : tensor<?x?xf32>
 // CHECK-DAG: %[[d1:.*]] = tensor.dim {{.*}} %[[c1]] : tensor<?x?xf32>
-// CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<4x1x16x2xf32>
 // CHECK: %[[mask:.*]] = vector.create_mask %[[d0]], %[[d1]] : vector<8x16xi1>
-// CHECK-DAG: %[[c0_2:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[c0_1:.*]] = arith.constant 0 : index
 // CHECK: %[[masked_read:.*]] = vector.mask %[[mask]] {
-// CHECK-SAME: vector.transfer_read %{{.*}}[%[[c0_2]], %[[c0_2]]], %[[cst]]
+// CHECK-SAME: vector.transfer_read %{{.*}}[%[[c0_1]], %[[c0_1]]], %[[cst]]
 // CHECK-SAME: {in_bounds = [true, true]} : tensor<?x?xf32>, vector<8x16xf32>
 // CHECK-SAME: } : vector<8x16xi1> -> vector<8x16xf32>
 // CHECK: %[[shape_cast:.*]] = vector.shape_cast %[[masked_read]] : vector<8x16xf32> to vector<4x2x1x16xf32>
 // CHECK: %[[transpose:.*]] = vector.transpose %[[shape_cast]], [0, 2, 3, 1] : vector<4x2x1x16xf32> to vector<4x1x16x2xf32>
+// CHECK-DAG: %[[c0_2:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<4x1x16x2xf32>
 // CHECK: %[[write:.*]] = vector.transfer_write %[[transpose]], %[[empty]][%[[c0_2]], %[[c0_2]], %[[c0_2]], %[[c0_2]]]
 // CHECK-SAME: {in_bounds = [true, true, true, true]} : vector<4x1x16x2xf32>, tensor<4x1x16x2xf32>
 // CHECK: return %[[write]] : tensor<4x1x16x2xf32>
@@ -539,15 +542,14 @@ func.func @test_vectorize_pack(%arg0: tensor<32x8x16xf32>, %arg1: tensor<32x4x1
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["tensor.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
-    transform.structured.vectorize %0 vector_sizes [32, 8, 16] : !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [32, 4, 1] : !transform.any_op
     transform.yield
   }
 }
 // CHECK-DAG: %[[cst:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK-DAG: %[[c32:.*]] = arith.constant 32 : index
 // CHECK-DAG: %[[c8:.*]] = arith.constant 8 : index
 // CHECK-DAG: %[[c16:.*]] = arith.constant 16 : index
-// CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<32x4x1x16x2xf32>
 // CHECK: %[[mask:.*]] = vector.create_mask %[[c32]], %[[c8]], %[[c16]] : vector<32x8x16xi1>
 // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
 // CHECK: %[[masked_read:.*]] = vector.mask %[[mask]] {
@@ -556,7 +558,9 @@ module attributes {transform.with_named_sequence} {
 // CHECK-SAME: } : vector<32x8x16xi1> -> vector<32x8x16xf32>
 // CHECK: %[[shape_cast:.*]] = vector.shape_cast %[[masked_read]] : vector<32x8x16xf32> to vector<32x4x2x1x16xf32>
 // CHECK: %[[transpose:.*]] = vector.transpose %[[shape_cast]], [0, 1, 3, 4, 2] : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
-// CHECK: %[[write:.*]] = vector.transfer_write %[[transpose]], %[[empty]][%[[c0]], %[[c0]], %[[c0]], %[[c0]], %[[c0]]]
+// CHECK-DAG: %[[c0_1:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<32x4x1x16x2xf32>
+// CHECK: %[[write:.*]] = vector.transfer_write %[[transpose]], %[[empty]][%[[c0_1]], %[[c0_1]], %[[c0_1]], %[[c0_1]], %[[c0_1]]]
 // CHECK-SAME: {in_bounds = [true, true, true, true, true]} : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
 // CHECK: return %[[write]] : tensor<32x4x1x16x2xf32>
```
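In the static pack test, the inner-tiled read `vector<32x8x16>` is split into (outer, tile) pairs by the `shape_cast` and reordered by the `transpose` so the tile dims land at the back. A plain-C++ check of those two shape steps (the inner-tile placement is inferred from the CHECK lines, since the test body is truncated above):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Static pack test: vector_sizes [32, 4, 1]; tiles 2 and 16 on the last
  // two dims. shape_cast: vector<32x8x16> -> vector<32x4x2x1x16>, i.e. each
  // tiled read dim splits into (outerSize, tileSize).
  std::vector<int64_t> tiledShape = {32, 4, 2, 1, 16};

  // vector.transpose permutation from the CHECK line: result dim i takes
  // source dim perm[i], moving the tile dims to the back.
  std::vector<int64_t> perm = {0, 1, 3, 4, 2};
  std::vector<int64_t> packedShape(tiledShape.size());
  for (size_t i = 0; i < perm.size(); ++i)
    packedShape[i] = tiledShape[perm[i]];

  for (int64_t d : packedShape)
    std::cout << d << ' '; // 32 4 1 16 2 -- the dest shape tensor<32x4x1x16x2>
  std::cout << '\n';
}
```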
