
[mlir] Add direct vectorization lowering for tensor.pack ops #78660


Merged: 8 commits, Feb 7, 2024
8 changes: 8 additions & 0 deletions mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
@@ -32,6 +32,14 @@ FailureOr<RankedTensorType>
computeTransposedType(RankedTensorType rankedTensorType,
ArrayRef<int64_t> transposeVector);

/// Given a tensor::PackOp, compute the permutation vector to shuffle the
/// packed shape into the shape before any outer or inner permutations have
/// been applied.
/// i.e. for a pack from an ABCD layout to an ABCDba:
/// The packed shape would be ABCDba.
/// The pre-permutation shape would be AaBbCD.
SmallVector<int64_t> getPackInverseDestPermutation(PackOp packOp);
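For illustration only (this snippet is not part of the patch), the permutation for the ABCD to ABCDba example above works out to [0, 5, 1, 4, 2, 3]; the value is derived by hand from the documented shapes, not taken from the implementation. A small standalone C++ check:

#include <array>
#include <cstdio>

int main() {
  // Packed shape for the ABCD -> ABCDba example: outer dims A, B, C, D
  // followed by the inner tiles b and a.
  const std::array<const char *, 6> packedShape = {"A", "B", "C", "D", "b", "a"};
  // Permutation that shuffles the packed shape back into the
  // pre-permutation (strip-mined) shape AaBbCD.
  const std::array<int, 6> perm = {0, 5, 1, 4, 2, 3};
  for (int idx : perm)
    std::printf("%s", packedShape[idx]);
  std::printf("\n"); // prints: AaBbCD
  return 0;
}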

/// A tensor.insert_slice is a cast-like operation if it merely rank-extends the
/// source tensor or inserts the source tensor into a destination tensor with
/// the same shape.
@@ -3152,7 +3152,7 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(

// TODO: Check that the correct number of vectorSizes was provided.
for (Operation *target : targets) {
if (!isa<linalg::LinalgOp, tensor::PadOp>(target)) {
if (!isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp>(target)) {
return mlir::emitSilenceableFailure(target->getLoc())
<< "Unsupported Op, cannot vectorize";
}
36 changes: 7 additions & 29 deletions mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -233,31 +233,11 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
rewriter.setInsertionPoint(packOp);

// 2. Compute the permutation vector to shuffle packed shape into the shape
// before any outer or inner permutations have been applied. The permutation
// can be obtained from two permutations:
// a) Compute the permutation vector to move the last `numPackedDims` into
// the `innerPosDims` of a shape of rank `packedRank`.
// b) Compute the permutation vector to move outer dims if the pack op
// has outer_dims_perm.
// Apply (b) permutation on (a) permutation to get the final permutation.
int64_t numPackedDims = packOp.getInnerDimsPos().size();
int64_t packedRank = packedTensorType.getRank();
auto lastDims = llvm::to_vector(
llvm::seq<int64_t>(packedRank - numPackedDims, packedRank));
// before any outer or inner permutations have been applied.
PackingMetadata packingMetadata = computePackingMetadata(
packedTensorType.getRank(), packOp.getInnerDimsPos());
SmallVector<int64_t> innerPositionsPerm = computePermutationVector(
packedRank, lastDims, packingMetadata.insertPositions);

SmallVector<int64_t> outerPos = packingMetadata.outerPositions;
ArrayRef<int64_t> outerPerm = packOp.getOuterDimsPerm();
if (!outerPerm.empty())
applyPermutationToVector(outerPos, outerPerm);
SmallVector<int64_t> outerPositionPerm = computePermutationVector(
packedRank, packingMetadata.outerPositions, outerPos);

SmallVector<int64_t> packedToStripMinedShapePerm = innerPositionsPerm;
applyPermutationToVector(packedToStripMinedShapePerm, outerPositionPerm);
SmallVector<int64_t> packedToStripMinedShapePerm =
tensor::getPackInverseDestPermutation(packOp);

// 3. Compute the stripMinedShape: this is the packed shape before any outer
// or inner permutations have been applied.
@@ -304,10 +284,6 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
DBGS() << "packedShape: ");
DBGSNL();
llvm::interleaveComma(outerPositionPerm, DBGS() << "outerPositionPerm: ");
DBGSNL(); llvm::interleaveComma(innerPositionsPerm,
DBGS() << "innerPositionsPerm: ");
DBGSNL();
llvm::interleaveComma(packedToStripMinedShapePerm,
DBGS() << "packedToStripMinedShapePerm: ");
DBGSNL(); llvm::interleaveComma(
@@ -332,9 +308,11 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
auto emptyOp =
rewriter.create<tensor::EmptyOp>(loc, packedTensorType, ValueRange{});
// Offsets.
SmallVector<OpFoldResult> zeros(packedRank, rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> zeros(packOp.getDestRank(),
rewriter.getIndexAttr(0));
// Strides.
SmallVector<OpFoldResult> ones(packedRank, rewriter.getIndexAttr(1));
SmallVector<OpFoldResult> ones(packOp.getDestRank(),
rewriter.getIndexAttr(1));
SmallVector<OpFoldResult> sizes =
tensor::getMixedSizes(rewriter, loc, packOp.getDest());

236 changes: 203 additions & 33 deletions mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
Expand Up @@ -19,18 +19,26 @@
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
Contributor Author: Sorry, I'll fix these extra includes in the next round of comments.

#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>
#include <type_traits>
@@ -1393,6 +1401,164 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
return success();
}

/// Given a tensor::PackOp, return the `dest` shape before any packing
/// permutations.
static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp,
ArrayRef<int64_t> destShape) {
return applyPermutation(destShape,
tensor::getPackInverseDestPermutation(packOp));
}

/// Create a TransferReadOp from `source` with static shape `readShape`. If the
/// vector type for the read is not the same as the type of `source`, then a
/// mask is created on the read.
static Value createReadOrMaskedRead(OpBuilder &builder, Location loc,
Value source, ArrayRef<int64_t> readShape,
Value padValue) {
assert(llvm::none_of(readShape,
[](int64_t s) { return s == ShapedType::kDynamic; }));
auto sourceShape = dyn_cast<ShapedType>(source.getType()).getShape();
assert(sourceShape.size() == readShape.size());
auto maskType = VectorType::get(readShape, builder.getI1Type());
auto vectorType = VectorType::get(readShape, padValue.getType());
int64_t readRank = readShape.size();
auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
auto transferReadOp = builder.create<vector::TransferReadOp>(
loc,
/*vectorType=*/vectorType,
/*source=*/source,
/*indices=*/SmallVector<Value>(readRank, zero),
/*padding=*/padValue,
/*inBounds=*/SmallVector<bool>(readRank, true));
if (llvm::equal(readShape, sourceShape)) {
return transferReadOp;
}
SmallVector<OpFoldResult> mixedSourceDims =
tensor::getMixedSizes(builder, loc, source);
Value mask =
builder.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
return mlir::vector::maskOperation(builder, transferReadOp, mask)
->getResult(0);
}
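As a concrete reading of the condition above, consider the shapes from the pack example documented further down: a tensor<32x7x16xf32> source read with shape [32, 8, 16]. The standalone sketch below only mirrors the shape comparison; it does not call into MLIR, and the numbers are illustrative:

#include <array>
#include <cstdio>

int main() {
  // Static shape of the source tensor and the requested read shape.
  const std::array<int, 3> sourceShape = {32, 7, 16};
  const std::array<int, 3> readShape = {32, 8, 16};
  // A mask is created only when the two differ; here dim 1 (7 vs. 8)
  // forces a masked vector.transfer_read, as in the example below.
  const bool needsMask = sourceShape != readShape;
  std::printf("needs mask: %s\n", needsMask ? "true" : "false"); // true
  return 0;
}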

/// Given an input, the mixed destSizes, and the vector sizes for vectorization,
/// create an empty destination tensor and create a TransferWriteOp from the
/// input to the empty tensor. If the destination shape is not the same as the
/// inputVectorSizes for the first rank(inputVectorSizes) dims, then create a
/// mask for the write.
static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
Value input,
SmallVector<OpFoldResult> destSizes,
ArrayRef<int64_t> inputVectorSizes) {
auto inputType = cast<VectorType>(input.getType());
Value dest = builder.create<tensor::EmptyOp>(loc, destSizes,
inputType.getElementType());
int64_t rank = cast<ShapedType>(dest.getType()).getRank();
auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
Operation *write = builder.create<vector::TransferWriteOp>(
loc,
/*vector=*/input,
/*source=*/dest,
/*indices=*/SmallVector<Value>(rank, zero),
/*inBounds=*/SmallVector<bool>(rank, true));
auto destShape = cast<ShapedType>(dest.getType()).getShape();
assert(llvm::none_of(
destShape.drop_front(inputVectorSizes.size()),
[](int64_t size) { return size == ShapedType::kDynamic; }) &&
"Only dims aligned with inputVectorSizes may be dynamic");
bool needMaskForWrite = !llvm::equal(
inputVectorSizes, destShape.take_front(inputVectorSizes.size()));
if (needMaskForWrite) {
SmallVector<int64_t> writeMaskShape;
writeMaskShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
writeMaskShape.append(destShape.begin() + inputVectorSizes.size(),
destShape.end());
auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
Value maskForWrite =
builder.create<vector::CreateMaskOp>(loc, writeMaskType, destSizes);
write = mlir::vector::maskOperation(builder, write, maskForWrite);
}
return write;
}

/// Vectorize tensor::PackOp with (1) static innerTiles and (2) constant
/// padding value into:
/// masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds
/// As in the following example:
///
/// %pack = tensor.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2]
/// into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
///
/// This pack would be vectorized to:
///
Contributor: please add a doc

/// %load = vector.mask %mask {
/// vector.transfer_read %arg0[%c0, %c0, %c0], %cst
/// {in_bounds = [true, true, true]} :
/// tensor<32x7x16xf32>, vector<32x8x16xf32>
/// } : vector<32x8x16xi1> -> vector<32x8x16xf32>
/// %shape_cast = vector.shape_cast %load : vector<32x8x16xf32>
/// to vector<32x4x2x1x16xf32>
/// %transpose = vector.transpose %shape_cast, [0, 1, 3, 4, 2]
/// : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
/// %write = vector.transfer_write %transpose,
/// %empty[%c0_0, %c0_0, %c0_0, %c0_0, %c0_0]
/// {in_bounds = [true, true, true, true, true]}
/// : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
static LogicalResult
vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
ArrayRef<int64_t> inputVectorSizes,
SmallVectorImpl<Value> &newResults) {
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(packOp);

Location loc = packOp.getLoc();
auto padValue = packOp.getPaddingValue();
if (!padValue) {
padValue = rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(packOp.getSourceType().getElementType()));
}
ReifiedRankedShapedTypeDims reifiedReturnShapes;
LogicalResult status =
cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation())
.reifyResultShapes(rewriter, reifiedReturnShapes);
(void)status; // prevent unused variable warning on non-assert builds.
assert(succeeded(status) && "failed to reify result shapes");

// Create masked TransferReadOp.
SmallVector<int64_t> inputShape(inputVectorSizes);
auto innerTiles = packOp.getStaticInnerTiles();
auto innerDimsPos = packOp.getInnerDimsPos();
auto outerDimsPerm = packOp.getOuterDimsPerm();
if (!outerDimsPerm.empty())
applyPermutationToVector(inputShape,
invertPermutationVector(outerDimsPerm));
for (auto [idx, size] : enumerate(innerTiles))
inputShape[innerDimsPos[idx]] *= size;
auto maskedRead = createReadOrMaskedRead(rewriter, loc, packOp.getSource(),
inputShape, padValue);

// Create ShapeCastOp.
SmallVector<int64_t> destShape(inputVectorSizes);
destShape.append(innerTiles.begin(), innerTiles.end());
auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape),
packOp.getDestType().getElementType());
auto shapeCastOp =
rewriter.create<vector::ShapeCastOp>(loc, tiledPackType, maskedRead);

// Create TransposeOp.
auto destPermutation =
invertPermutationVector(tensor::getPackInverseDestPermutation(packOp));
auto transposeOp = rewriter.create<vector::TransposeOp>(
loc, shapeCastOp.getResult(), destPermutation);

// Create TransferWriteOp.
Operation *write =
createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(),
reifiedReturnShapes[0], inputVectorSizes);
newResults.push_back(write->getResult(0));
return success();
}
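For orientation, a hedged sketch of how this new hook is typically reached from C++ through the generic linalg::vectorize entry point shown further down in this file; the rewriter and packOp handles are assumed to exist in the caller, and the vector sizes mirror the 32x8x16 to 32x4x1x16x2 example above:

  // Sketch only: the vector sizes span the source rank (3 here) and
  // correspond to the leading dims of the packed destination shape.
  SmallVector<int64_t> vectorSizes = {32, 4, 1};
  if (failed(linalg::vectorize(rewriter, packOp, vectorSizes)))
    return failure();
  // On success the pack has been rewritten into the masked
  // transfer_read -> shape_cast -> transpose -> transfer_write sequence;
  // for this example the transpose permutation is [0, 1, 3, 4, 2], the
  // inverse of the [0, 1, 4, 2, 3] returned by
  // getPackInverseDestPermutation (derived from the documented shapes).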

/// Vectorize a `padOp` with (1) static result type, (2) constant padding value
/// and (3) all-zero lowPad to
/// `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
@@ -1402,9 +1568,6 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
SmallVectorImpl<Value> &newResults) {
auto padValue = padOp.getConstantPaddingValue();
Location loc = padOp.getLoc();
int64_t rank = inputVectorSizes.size();
auto maskType = VectorType::get(inputVectorSizes, rewriter.getI1Type());
auto vectorType = VectorType::get(inputVectorSizes, padValue.getType());

// transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))
OpBuilder::InsertionGuard g(rewriter);
@@ -1416,36 +1579,10 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
.reifyResultShapes(rewriter, reifiedReturnShapes);
(void)status; // prevent unused variable warning on non-assert builds
assert(succeeded(status) && "failed to reify result shapes");
auto emptyOp = rewriter.create<tensor::EmptyOp>(loc, reifiedReturnShapes[0],
padValue.getType());
SmallVector<OpFoldResult> mixedSourceDims =
tensor::getMixedSizes(rewriter, loc, padOp.getSource());
Value mask =
rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
auto transferReadOp = rewriter.create<vector::TransferReadOp>(
loc,
/*vectorType=*/vectorType,
/*source=*/padOp.getSource(),
/*indices=*/SmallVector<Value>(rank, zero),
/*padding=*/padValue,
/*inBounds=*/SmallVector<bool>(rank, true));
auto maskedOp = cast<vector::MaskOp>(
mlir::vector::maskOperation(rewriter, transferReadOp, mask));
Operation *write = rewriter.create<vector::TransferWriteOp>(
loc,
/*vector=*/maskedOp->getResult(0),
/*source=*/emptyOp,
/*indices=*/SmallVector<Value>(rank, zero),
/*inBounds=*/SmallVector<bool>(rank, true));
bool needMaskForWrite = llvm::any_of(
llvm::zip_equal(inputVectorSizes, padOp.getResultType().getShape()),
[](auto it) { return std::get<0>(it) != std::get<1>(it); });
if (needMaskForWrite) {
Value maskForWrite = rewriter.create<vector::CreateMaskOp>(
loc, maskType, reifiedReturnShapes[0]);
write = mlir::vector::maskOperation(rewriter, write, maskForWrite);
}
auto maskedRead = createReadOrMaskedRead(rewriter, loc, padOp.getSource(),
inputVectorSizes, padValue);
Operation *write = createWriteOrMaskedWrite(
rewriter, loc, maskedRead, reifiedReturnShapes[0], inputVectorSizes);
newResults.push_back(write->getResult(0));
return success();
}
@@ -1585,6 +1722,32 @@ vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
return success();
}

/// TODO: Use a matcher to check for a constant padding value.
static LogicalResult
vectorizePackOpPrecondition(tensor::PackOp packOp,
ArrayRef<int64_t> inputVectorSizes) {
auto padValue = packOp.getPaddingValue();
if (padValue && !padValue.getDefiningOp<arith::ConstantOp>()) {
LDBG("pad value is not constant: " << packOp << "\n");
return failure();
}

Comment on lines +1725 to +1734, Contributor Author: I realized there is no matcher for constant float values like there is for constant int values. I didn't know exactly how this should be done.

/// If ofr is a constant integer or an IntegerAttr, return the integer.
std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
// Case 1: Check for Constant integer.
if (auto val = llvm::dyn_cast_if_present<Value>(ofr)) {
APSInt intVal;
if (matchPattern(val, m_ConstantInt(&intVal)))
return intVal.getSExtValue();
return std::nullopt;
}
// Case 2: Check for IntegerAttr.
Attribute attr = llvm::dyn_cast_if_present<Attribute>(ofr);
if (auto intAttr = dyn_cast_or_null<IntegerAttr>(attr))
return intAttr.getValue().getSExtValue();
return std::nullopt;
}

For int, getConstantIntValue passes an uninitialized APSInt to the matcher. For the analogous float matcher, I'm not sure of the best way to implement a similar function, since there are several possible floating-point semantics. I suppose there could be a util function that simply checks whether the value is constant and does not return the value; an arbitrary float semantics could then be used.
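
One possible shape for such a helper, sketched under the assumption that the generic m_Constant(Attribute *) binder from mlir/IR/Matchers.h is acceptable and that <optional> and llvm/ADT/APFloat.h are available; the name getConstantAPFloatValue is hypothetical, not something in the tree:

// Hypothetical helper (illustration only): return the constant float
// value of `val` if it is produced by a constant-like op, without
// committing to a particular float semantics up front.
static std::optional<APFloat> getConstantAPFloatValue(Value val) {
  Attribute attr;
  if (!matchPattern(val, m_Constant(&attr)))
    return std::nullopt;
  if (auto floatAttr = dyn_cast<FloatAttr>(attr))
    return floatAttr.getValue();
  return std::nullopt;
}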

Contributor: Good catch. I think the current implementation is okay; it checks whether the values come from arith.constant.

Contributor: Hey Max, I just found something that might address the concern:

// Return the constant attribute, or null if the Operation isn't a constant.
Attribute getConstantAttr(Operation *constantOp) {
  Attribute constant;
  matchPattern(constantOp, m_Constant(&constant));
  return constant;
}

https://mlir.llvm.org/getting_started/Faq/#many-dialects-define-a-constant-operation-how-do-i-get-a-constant-value-generically

ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
if (failed(isValidMaskedInputVector(
resultTensorShape.take_front(packOp.getSourceRank()),
inputVectorSizes)))
return failure();

if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
return !getConstantIntValue(v).has_value();
})) {
LDBG("inner_tiles must be constant: " << packOp << "\n");
return failure();
}

return success();
}

static LogicalResult
vectorizePadOpPrecondition(tensor::PadOp padOp,
ArrayRef<int64_t> inputVectorSizes) {
@@ -1644,6 +1807,9 @@ LogicalResult mlir::linalg::vectorizeOpPrecondition(
.Case<tensor::PadOp>([&](auto padOp) {
return vectorizePadOpPrecondition(padOp, inputVectorSizes);
})
.Case<tensor::PackOp>([&](auto packOp) {
return vectorizePackOpPrecondition(packOp, inputVectorSizes);
})
.Default([](auto) { return failure(); });
}

@@ -1732,6 +1898,10 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
results);
})
.Case<tensor::PackOp>([&](auto packOp) {
return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
results);
})
.Default([](auto) { return failure(); });

if (failed(vectorizeResult)) {