Commit d5a0dec

Fixed all the issues pointed out by HanHan and Diego.
1 parent 744a291 commit d5a0dec

File tree

4 files changed: +108 -42 lines changed

mlir/include/mlir/Dialect/Tensor/Utils/Utils.h

Lines changed: 5 additions & 0 deletions
@@ -41,6 +41,11 @@ computeTransposedType(RankedTensorType rankedTensorType,
 SmallVector<int64_t> getPackUnPackInverseDestPerm(
     std::variant<tensor::PackOp, tensor::UnPackOp> packOp);

+/// Unpack requires some packing metadata, so create another
+/// function where this value is passed by reference.
+SmallVector<int64_t> getPackUnPackInverseDestPerm(
+    std::variant<tensor::PackOp, tensor::UnPackOp> packOp,
+    PackingMetadata &PackingMetadata);
 /// A tensor.insert_slice is a cast-like operation if it merely rank-extends the
 /// source tensor or inserts the source tensor into a destination tensor with
 /// the same shape.
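
Why the by-reference overload: callers such as the unpack vectorizer (see Vectorization.cpp below) need both the inverse permutation and the PackingMetadata that produced it. A minimal caller sketch follows, assuming the two includes are sufficient; the wrapper name getUnpackDestPermWithMetadata is hypothetical, while the two MLIR calls are the ones this commit adds and uses:

    // Hedged sketch of the intended call pattern, not code from this commit.
    #include "mlir/Dialect/Tensor/Utils/Utils.h"   // getPackUnPackInverseDestPerm
    #include "mlir/Dialect/Utils/IndexingUtils.h"  // invertPermutationVector

    static llvm::SmallVector<int64_t>
    getUnpackDestPermWithMetadata(mlir::tensor::UnPackOp unpackOp,
                                  mlir::PackingMetadata &metadata) {
      // The overload fills `metadata` as a side effect, so the caller no
      // longer recomputes computePackingMetadata() separately.
      return mlir::invertPermutationVector(
          mlir::tensor::getPackUnPackInverseDestPerm(unpackOp, metadata));
    }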

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 15 additions & 15 deletions
@@ -1571,11 +1571,12 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
   return success();
 }

-/// Vectorize an `tensor::UnPackOp` without OuterDimsPerms to these 4 Ops:
-///   Vector::TransferReadOp - Reads the Vector Array of Source data
-///   vector::TransposeOp - Transpose the Source
-///   ShapeCastOp - Reshapes the data based on the target.
-///   vector::TransferWriteOp. - Write the result vector back.
+/// Vectorize a `tensor::UnPackOp` without OuterDimsPerms to these 4 Ops:
+///   Vector::TransferReadOp - Reads a vector from the source tensor
+///   vector::TransposeOp - Transpose the Source tensor
+///   ShapeCastOp - Reshape the data based on the target.
+///   vector::TransferWriteOp. - Write the result vector back to the destination
+/// tensor
 static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
                                          tensor::UnPackOp unpackOp,
                                          ArrayRef<int64_t> inputVectorSizes,
@@ -1610,26 +1611,21 @@ static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
     LDBG("Unable to reify result shapes of " << unpackOp);
     return failure();
   }
-  int64_t unpackRank = unpackTensorType.getRank();
   Location loc = unpackOp->getLoc();

+  // Read result, mask if necessary.
   Value readResult = createReadOrMaskedRead(
       rewriter, loc, unpackOp.getSource(),
       llvm::ArrayRef<int64_t>(readMaskShape.begin(), readMaskShape.end()),
       nullptr);

-  int64_t numPackedDim = unpackOp.getInnerDimsPos().size();
-  llvm::SmallVector<int64_t> lastDims = llvm::to_vector(
-      llvm::seq<int64_t>(unpackRank - numPackedDim, unpackRank));
-  PackingMetadata packMetadata =
-      computePackingMetadata(unpackRank, unpackOp.getInnerDimsPos());
-  SmallVector<int64_t> lastDimToInsertPosPerm = computePermutationVector(
-      unpackRank, lastDims, packMetadata.insertPositions);
+  PackingMetadata packMetadata;
+  SmallVector<int64_t> lastDimToInsertPosPerm = invertPermutationVector(
+      tensor::getPackUnPackInverseDestPerm(unpackOp, packMetadata));
   ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType());
   SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
   mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
   applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
-
   RankedTensorType stripMineTensorType =
       RankedTensorType::Builder(stripMineShape, stripMineElemType, {})
           .setShape(stripMineShape);
@@ -1646,8 +1642,12 @@ static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
   vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
       loc, vecCollapsedType, transposeOp->getResult(0));

+  // writeMaskShape must match the shape-cast shape for dynamic sizes,
+  // otherwise the verifier complains that the mask size is invalid.
   SmallVector<int64_t> writeMaskShape(
-      shapeCastOp.getResultVectorType().getShape());
+      unpackOp.getDestType().hasStaticShape()
+          ? inputVectorSizes
+          : shapeCastOp.getResultVectorType().getShape());
   Operation *write =
       createWriteOrMaskedWrite(rewriter, loc, shapeCastOp.getResult(),
                                reifiedRetShapes[0], writeMaskShape);
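
The net effect of this change is that the transpose permutation now comes from inverting the shared helper's result. As a self-contained illustration of the permutation helpers' semantics, here is a plain-C++ sketch; the assumption (believed correct, but not taken from this commit) is that MLIR's applyPermutationToVector computes out[i] = v[perm[i]] and invertPermutationVector computes inv[perm[i]] = i. The numbers come from the @test_vectorize_unpack test below.

    // Standalone sketch of the permutation semantics; compiles as C++11+.
    #include <cassert>
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Assumed applyPermutationToVector semantics: out[i] = v[perm[i]].
    static std::vector<int64_t> applyPerm(const std::vector<int64_t> &v,
                                          const std::vector<int64_t> &perm) {
      std::vector<int64_t> out(perm.size());
      for (std::size_t i = 0; i < perm.size(); ++i)
        out[i] = v[perm[i]];
      return out;
    }

    // Assumed invertPermutationVector semantics: inv[perm[i]] = i.
    static std::vector<int64_t> invertPerm(const std::vector<int64_t> &perm) {
      std::vector<int64_t> inv(perm.size());
      for (std::size_t i = 0; i < perm.size(); ++i)
        inv[perm[i]] = static_cast<int64_t>(i);
      return inv;
    }

    int main() {
      // @test_vectorize_unpack: the read vector <16x8x32x16> is transposed
      // with [0, 2, 1, 3], giving <16x32x8x16> before the shape_cast.
      std::vector<int64_t> shape = {16, 8, 32, 16};
      std::vector<int64_t> perm = {0, 2, 1, 3};
      assert(applyPerm(shape, perm) == (std::vector<int64_t>{16, 32, 8, 16}));
      // This particular permutation happens to be its own inverse.
      assert(invertPerm(perm) == perm);
      return 0;
    }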

mlir/lib/Dialect/Tensor/Utils/Utils.cpp

Lines changed: 27 additions & 21 deletions
@@ -75,18 +75,26 @@ mlir::tensor::computeTransposedType(RankedTensorType rankedTensorType,

 SmallVector<int64_t> mlir::tensor::getPackUnPackInverseDestPerm(
     std::variant<tensor::PackOp, tensor::UnPackOp> op) {
+  PackingMetadata pMetaData;
+  return getPackUnPackInverseDestPerm(op, pMetaData);
+}
+
+SmallVector<int64_t> mlir::tensor::getPackUnPackInverseDestPerm(
+    std::variant<tensor::PackOp, tensor::UnPackOp> op,
+    PackingMetadata &packingMetadata) {

   llvm::ArrayRef<int64_t> innerDimsPos, outerPerm;
-  RankedTensorType destType;
-  if (std::holds_alternative<tensor::PackOp>(op)) {
+  int64_t rank = 0;
+  bool isPackOp = std::holds_alternative<tensor::PackOp>(op);
+  if (isPackOp) {
     tensor::PackOp packOp = std::get<tensor::PackOp>(op);
     innerDimsPos = packOp.getInnerDimsPos();
-    destType = packOp.getDestType();
+    rank = packOp.getDestType().getRank();
     outerPerm = packOp.getOuterDimsPerm();
   } else {
     tensor::UnPackOp unpackOp = std::get<tensor::UnPackOp>(op);
     innerDimsPos = unpackOp.getInnerDimsPos();
-    destType = unpackOp.getDestType();
+    rank = unpackOp.getSourceType().getRank();
     outerPerm = unpackOp.getOuterDimsPerm();
   }
   // The permutation can be obtained from two permutations:
@@ -96,23 +104,21 @@ SmallVector<int64_t> mlir::tensor::getPackUnPackInverseDestPerm(
   //    has outer_dims_perm.
   // Apply (b) permutation on (a) permutation to get the final permutation.
   int64_t numPackedDims = innerDimsPos.size();
-  int64_t packedRank = destType.getRank();
-  auto lastDims = llvm::to_vector(
-      llvm::seq<int64_t>(packedRank - numPackedDims, packedRank));
-  PackingMetadata packingMetadata =
-      computePackingMetadata(destType.getRank(), innerDimsPos);
-  SmallVector<int64_t> innerPositionsPerm = computePermutationVector(
-      packedRank, lastDims, packingMetadata.insertPositions);
-
-  SmallVector<int64_t> outerPos = packingMetadata.outerPositions;
-  if (!outerPerm.empty())
-    applyPermutationToVector(outerPos, outerPerm);
-  SmallVector<int64_t> outerPositionPerm = computePermutationVector(
-      packedRank, packingMetadata.outerPositions, outerPos);
-
-  SmallVector<int64_t> packInverseDestPermutation = innerPositionsPerm;
-  applyPermutationToVector(packInverseDestPermutation, outerPositionPerm);
-  return packInverseDestPermutation;
+  auto lastDims =
+      llvm::to_vector(llvm::seq<int64_t>(rank - numPackedDims, rank));
+  packingMetadata = computePackingMetadata(rank, innerDimsPos);
+  SmallVector<int64_t> innerPositionsPerm =
+      computePermutationVector(rank, lastDims, packingMetadata.insertPositions);
+
+  if (isPackOp) {
+    SmallVector<int64_t> outerPos = packingMetadata.outerPositions;
+    if (!outerPerm.empty())
+      applyPermutationToVector(outerPos, outerPerm);
+    SmallVector<int64_t> outerPositionPerm = computePermutationVector(
+        rank, packingMetadata.outerPositions, outerPos);
+    applyPermutationToVector(innerPositionsPerm, outerPositionPerm);
+  }
+  return innerPositionsPerm;
 }

 bool mlir::tensor::isCastLikeInsertSliceOp(InsertSliceOp op) {
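
To see how innerPositionsPerm comes out in the common case, a worked sketch follows. Here makePerm is a hypothetical stand-in for computePermutationVector's placement behavior as understood by this editor (place positions[i] at index desired[i], fill remaining slots in increasing order), and the insertPositions value [1, 3] is inferred from the tests below, not computed by MLIR in this block.

    // Worked example: rank-4 unpack with inner_dims_pos = [0, 1].
    #include <cassert>
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Hypothetical model of computePermutationVector: put positions[i] at
    // index desired[i], then fill the remaining slots with the unused dims
    // in increasing order.
    static std::vector<int64_t> makePerm(int64_t rank,
                                         const std::vector<int64_t> &positions,
                                         const std::vector<int64_t> &desired) {
      std::vector<int64_t> perm(static_cast<std::size_t>(rank), -1);
      std::vector<bool> used(static_cast<std::size_t>(rank), false);
      for (std::size_t i = 0; i < positions.size(); ++i) {
        perm[desired[i]] = positions[i];
        used[positions[i]] = true;
      }
      int64_t next = 0;
      for (int64_t &slot : perm) {
        if (slot != -1)
          continue;
        while (used[next])
          ++next;
        slot = next++;
      }
      return perm;
    }

    int main() {
      // The two tile dims are the trailing dims [2, 3]; the packing metadata
      // wants them inserted at positions [1, 3], next to their outer dims.
      std::vector<int64_t> perm = makePerm(4, {2, 3}, {1, 3});
      // Matches the transpose permutation in the vectorization tests.
      assert(perm == (std::vector<int64_t>{0, 2, 1, 3}));
      return 0;
    }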

mlir/test/Dialect/Linalg/vectorization.mlir

Lines changed: 61 additions & 6 deletions
@@ -691,12 +691,12 @@ func.func @test_vectorize_dynamic_shapes_unpack(%arg0: tensor<?x?xf32>, %arg1: t
 // CHECK: %[[CNST2:.*]] = arith.constant 2 : index
 // CHECK: %[[readMsk0:.*]] = vector.create_mask %[[DIM4]], %[[DIM6]], %[[CNST16]], %[[CNST2]] : vector<2x1x16x2xi1>
 // CHECK: %[[read0:.*]] = vector.mask %[[readMsk0]] {{.*}} vector.transfer_read %{{.*}} : tensor<?x?x16x2xf32>, vector<2x1x16x2xf32> } : vector<2x1x16x2xi1> -> vector<2x1x16x2xf32>
-// CHECK: %[[trans0:.*]] = vector.transpose %[[read0]], [0, 3, 1, 2] : vector<2x1x16x2xf32> to vector<2x2x1x16xf32>
-// CHEdCK: %[[sc0:.*]] = vector.shape_cast %[[trans0]] : vector<4x2x1x16xf32> to vector<4x16xf32>
-// CHEdCK: %[[empt0:.*]] = tensor.empty
-// CHEdCK: %[[writeMsk0:.*]] = vector.create_mask {{.*}} : vector<4x16xi1>
-// CHEdCK: %[[write0:.*]] = vector.mask %[[writeMsk0:.*]] {{.*}} vector.transfer_write %[[sc0]], %[[empt0]]
-// CHEdCK: return %[[write0]]
+// CHECK: %[[trans0:.*]] = vector.transpose %[[read0]], [0, 2, 3, 1] : vector<2x1x16x2xf32> to vector<2x16x2x1xf32>
+// CHECK: %[[sc0:.*]] = vector.shape_cast %[[trans0]] : vector<2x16x2x1xf32> to vector<32x2xf32>
+// CHECK: %[[empt0:.*]] = tensor.empty
+// CHECK: %[[writeMsk0:.*]] = vector.create_mask {{.*}} : vector<32x2xi1>
+// CHECK: %[[write0:.*]] = vector.mask %[[writeMsk0:.*]] {{.*}} vector.transfer_write %[[sc0]], %[[empt0]]
+// CHECK: return %[[write0]]
 %ret = tensor.unpack %arg1 inner_dims_pos = [1, 0] inner_tiles = [16, 2] into %arg0 : tensor<?x?x16x2xf32> -> tensor<?x?xf32>
 return %ret : tensor<?x?xf32>
 }
@@ -707,3 +707,58 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// CHECK-LABEL: func @test_vectorize_unpack
+func.func @test_vectorize_unpack(%source: tensor<8x8x32x16xf32>, %dest: tensor<256x128xf32>) -> tensor<256x128xf32> {
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+  // CHECK: %[[C8:.*]] = arith.constant 8 : index
+  // CHECK: %[[C80:.*]] = arith.constant 8 : index
+  // CHECK: %[[C32:.*]] = arith.constant 32 : index
+  // CHECK: %[[C16:.*]] = arith.constant 16 : index
+  // CHECK: %[[MSK0:.*]] = vector.create_mask %c8, %c8_0, %c32, %c16 : vector<16x8x32x16xi1>
+  // CHECK: %[[READ0:.*]] = vector.mask %[[MSK0]] {{.*}} : vector<16x8x32x16xi1> -> vector<16x8x32x16xf32>
+  // CHECK: %[[TRANSP0:.*]] = vector.transpose %[[READ0]], [0, 2, 1, 3] : vector<16x8x32x16xf32> to vector<16x32x8x16xf32>
+  // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP0]] : vector<16x32x8x16xf32> to vector<512x128xf32>
+  // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<256x128xf32>
+  // CHECK: %[[C01:.*]] = arith.constant 0 : index
+  // CHECK: %[[C256:.*]] = arith.constant 256 : index
+  // CHECK: %[[C128:.*]] = arith.constant 128 : index
+  // CHECK: %[[WRITEMSK:.*]] = vector.create_mask %[[C256]], %[[C128]] : vector<512x128xi1>
+  // CHECK: %[[WRIT:.*]] = vector.mask %[[WRITEMSK]] {{.*}} : vector<512x128xi1> -> tensor<256x128xf32>
+  // CHECK: return %[[WRIT]] : tensor<256x128xf32>
+  %0 = tensor.unpack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
+  return %0 : tensor<256x128xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [512, 128] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func @test_vectorize_unpack_no_masks
+func.func @test_vectorize_unpack_no_masks(%source: tensor<8x8x32x16xf32>, %dest: tensor<256x128xf32>) -> tensor<256x128xf32> {
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+  // CHECK: %[[READ:.*]] = vector.transfer_read {{.*}} : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32>
+  // CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [0, 2, 1, 3] : vector<8x8x32x16xf32> to vector<8x32x8x16xf32>
+  // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<8x32x8x16xf32> to vector<256x128xf32>
+  // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<256x128xf32>
+  // CHECK: %[[C00:.*]] = arith.constant 0 : index
+  // CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], {{.*}} : vector<256x128xf32>, tensor<256x128xf32>
+  // CHECK: return %[[WRIT]] : tensor<256x128xf32>
+  %0 = tensor.unpack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
+  return %0 : tensor<256x128xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [256, 128] : !transform.any_op
+    transform.yield
+  }
+}
