-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir] Add direct vectorization lowering for tensor.pack
ops
#78660
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
3451206
8ad3ad7
a0931bd
e4950ce
14a73f3
b11affe
6c0e2a1
5c5278c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -19,18 +19,26 @@ | |||||||||||||||||||||||||||||||
#include "mlir/Dialect/Linalg/Transforms/Transforms.h" | ||||||||||||||||||||||||||||||||
#include "mlir/Dialect/Linalg/Utils/Utils.h" | ||||||||||||||||||||||||||||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h" | ||||||||||||||||||||||||||||||||
#include "mlir/Dialect/Tensor/Utils/Utils.h" | ||||||||||||||||||||||||||||||||
#include "mlir/Dialect/Utils/IndexingUtils.h" | ||||||||||||||||||||||||||||||||
#include "mlir/Dialect/Utils/StructuredOpsUtils.h" | ||||||||||||||||||||||||||||||||
#include "mlir/Dialect/Vector/IR/VectorOps.h" | ||||||||||||||||||||||||||||||||
#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h" | ||||||||||||||||||||||||||||||||
#include "mlir/IR/AffineExpr.h" | ||||||||||||||||||||||||||||||||
#include "mlir/IR/Builders.h" | ||||||||||||||||||||||||||||||||
#include "mlir/IR/BuiltinTypeInterfaces.h" | ||||||||||||||||||||||||||||||||
#include "mlir/IR/BuiltinTypes.h" | ||||||||||||||||||||||||||||||||
#include "mlir/IR/OpDefinition.h" | ||||||||||||||||||||||||||||||||
#include "mlir/IR/PatternMatch.h" | ||||||||||||||||||||||||||||||||
#include "mlir/Support/LLVM.h" | ||||||||||||||||||||||||||||||||
#include "mlir/Transforms/RegionUtils.h" | ||||||||||||||||||||||||||||||||
#include "llvm/ADT/STLExtras.h" | ||||||||||||||||||||||||||||||||
#include "llvm/ADT/Sequence.h" | ||||||||||||||||||||||||||||||||
#include "llvm/ADT/SmallVector.h" | ||||||||||||||||||||||||||||||||
#include "llvm/ADT/TypeSwitch.h" | ||||||||||||||||||||||||||||||||
#include "llvm/ADT/iterator_range.h" | ||||||||||||||||||||||||||||||||
#include "llvm/Support/Debug.h" | ||||||||||||||||||||||||||||||||
#include "llvm/Support/MathExtras.h" | ||||||||||||||||||||||||||||||||
#include "llvm/Support/raw_ostream.h" | ||||||||||||||||||||||||||||||||
#include <optional> | ||||||||||||||||||||||||||||||||
#include <type_traits> | ||||||||||||||||||||||||||||||||
|
@@ -1393,6 +1401,164 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state, | |||||||||||||||||||||||||||||||
return success(); | ||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
/// Given a tensor::PackOp, return the `dest` shape before any packing | ||||||||||||||||||||||||||||||||
/// permutations. | ||||||||||||||||||||||||||||||||
static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp, | ||||||||||||||||||||||||||||||||
ArrayRef<int64_t> destShape) { | ||||||||||||||||||||||||||||||||
return applyPermutation(destShape, | ||||||||||||||||||||||||||||||||
tensor::getPackInverseDestPermutation(packOp)); | ||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
/// Create a TransferReadOp from `source` with static shape `readShape`. If the | ||||||||||||||||||||||||||||||||
/// vector type for the read is not the same as the type of `source`, then a | ||||||||||||||||||||||||||||||||
/// mask is created on the read. | ||||||||||||||||||||||||||||||||
static Value createReadOrMaskedRead(OpBuilder &builder, Location loc, | ||||||||||||||||||||||||||||||||
Value source, ArrayRef<int64_t> readShape, | ||||||||||||||||||||||||||||||||
Value padValue) { | ||||||||||||||||||||||||||||||||
assert(llvm::none_of(readShape, | ||||||||||||||||||||||||||||||||
[](int64_t s) { return s == ShapedType::kDynamic; })); | ||||||||||||||||||||||||||||||||
auto sourceShape = dyn_cast<ShapedType>(source.getType()).getShape(); | ||||||||||||||||||||||||||||||||
assert(sourceShape.size() == readShape.size()); | ||||||||||||||||||||||||||||||||
auto maskType = VectorType::get(readShape, builder.getI1Type()); | ||||||||||||||||||||||||||||||||
auto vectorType = VectorType::get(readShape, padValue.getType()); | ||||||||||||||||||||||||||||||||
int64_t readRank = readShape.size(); | ||||||||||||||||||||||||||||||||
auto zero = builder.create<arith::ConstantIndexOp>(loc, 0); | ||||||||||||||||||||||||||||||||
auto transferReadOp = builder.create<vector::TransferReadOp>( | ||||||||||||||||||||||||||||||||
loc, | ||||||||||||||||||||||||||||||||
/*vectorType=*/vectorType, | ||||||||||||||||||||||||||||||||
/*source=*/source, | ||||||||||||||||||||||||||||||||
/*indices=*/SmallVector<Value>(readRank, zero), | ||||||||||||||||||||||||||||||||
/*padding=*/padValue, | ||||||||||||||||||||||||||||||||
/*inBounds=*/SmallVector<bool>(readRank, true)); | ||||||||||||||||||||||||||||||||
hanhanW marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||
if (llvm::equal(readShape, sourceShape)) { | ||||||||||||||||||||||||||||||||
return transferReadOp; | ||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||
SmallVector<OpFoldResult> mixedSourceDims = | ||||||||||||||||||||||||||||||||
tensor::getMixedSizes(builder, loc, source); | ||||||||||||||||||||||||||||||||
Value mask = | ||||||||||||||||||||||||||||||||
builder.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims); | ||||||||||||||||||||||||||||||||
return mlir::vector::maskOperation(builder, transferReadOp, mask) | ||||||||||||||||||||||||||||||||
->getResult(0); | ||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
/// Given an input, the mixed destSizes, and the vector sizes for vectorization, | ||||||||||||||||||||||||||||||||
/// create an empty destination tensor and create a TransferWriteOp from the | ||||||||||||||||||||||||||||||||
/// input to the empty tensor. If the destination shape is not the same as the | ||||||||||||||||||||||||||||||||
/// inputVectorSizes for the first rank(inputVectorSizes) dims, then create a | ||||||||||||||||||||||||||||||||
/// mask for the write. | ||||||||||||||||||||||||||||||||
static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc, | ||||||||||||||||||||||||||||||||
Value input, | ||||||||||||||||||||||||||||||||
SmallVector<OpFoldResult> destSizes, | ||||||||||||||||||||||||||||||||
ArrayRef<int64_t> inputVectorSizes) { | ||||||||||||||||||||||||||||||||
auto inputType = cast<VectorType>(input.getType()); | ||||||||||||||||||||||||||||||||
Value dest = builder.create<tensor::EmptyOp>(loc, destSizes, | ||||||||||||||||||||||||||||||||
inputType.getElementType()); | ||||||||||||||||||||||||||||||||
int64_t rank = cast<ShapedType>(dest.getType()).getRank(); | ||||||||||||||||||||||||||||||||
auto zero = builder.create<arith::ConstantIndexOp>(loc, 0); | ||||||||||||||||||||||||||||||||
Operation *write = builder.create<vector::TransferWriteOp>( | ||||||||||||||||||||||||||||||||
loc, | ||||||||||||||||||||||||||||||||
/*vector=*/input, | ||||||||||||||||||||||||||||||||
/*source=*/dest, | ||||||||||||||||||||||||||||||||
/*indices=*/SmallVector<Value>(rank, zero), | ||||||||||||||||||||||||||||||||
/*inBounds=*/SmallVector<bool>(rank, true)); | ||||||||||||||||||||||||||||||||
auto destShape = cast<ShapedType>(dest.getType()).getShape(); | ||||||||||||||||||||||||||||||||
assert(llvm::none_of( | ||||||||||||||||||||||||||||||||
destShape.drop_front(inputVectorSizes.size()), | ||||||||||||||||||||||||||||||||
[](int64_t size) { return size == ShapedType::kDynamic; }) && | ||||||||||||||||||||||||||||||||
"Only dims aligned with inputVectorSizes may be dynamic"); | ||||||||||||||||||||||||||||||||
bool needMaskForWrite = !llvm::equal( | ||||||||||||||||||||||||||||||||
inputVectorSizes, destShape.take_front(inputVectorSizes.size())); | ||||||||||||||||||||||||||||||||
if (needMaskForWrite) { | ||||||||||||||||||||||||||||||||
SmallVector<int64_t> writeMaskShape; | ||||||||||||||||||||||||||||||||
writeMaskShape.append(inputVectorSizes.begin(), inputVectorSizes.end()); | ||||||||||||||||||||||||||||||||
writeMaskShape.append(destShape.begin() + inputVectorSizes.size(), | ||||||||||||||||||||||||||||||||
destShape.end()); | ||||||||||||||||||||||||||||||||
auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type()); | ||||||||||||||||||||||||||||||||
Value maskForWrite = | ||||||||||||||||||||||||||||||||
builder.create<vector::CreateMaskOp>(loc, writeMaskType, destSizes); | ||||||||||||||||||||||||||||||||
write = mlir::vector::maskOperation(builder, write, maskForWrite); | ||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||
return write; | ||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
/// Vectorize tensor::PackOp with (1) static innerTiles and (2) constant | ||||||||||||||||||||||||||||||||
/// padding value into: | ||||||||||||||||||||||||||||||||
/// masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds | ||||||||||||||||||||||||||||||||
/// As in the following example: | ||||||||||||||||||||||||||||||||
/// | ||||||||||||||||||||||||||||||||
/// %pack = tensor.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2] | ||||||||||||||||||||||||||||||||
/// into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32> | ||||||||||||||||||||||||||||||||
/// | ||||||||||||||||||||||||||||||||
/// This pack would be vectorized to: | ||||||||||||||||||||||||||||||||
/// | ||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. please add a doc |
||||||||||||||||||||||||||||||||
/// %load = vector.mask %mask { | ||||||||||||||||||||||||||||||||
/// vector.transfer_read %arg0[%c0, %c0, %c0], %cst | ||||||||||||||||||||||||||||||||
/// {in_bounds = [true, true, true]} : | ||||||||||||||||||||||||||||||||
/// tensor<32x7x16xf32>, vector<32x8x16xf32> | ||||||||||||||||||||||||||||||||
/// } : vector<32x8x16xi1> -> vector<32x8x16xf32> | ||||||||||||||||||||||||||||||||
/// %shape_cast = vector.shape_cast %load : vector<32x8x16xf32> | ||||||||||||||||||||||||||||||||
/// to vector<32x4x2x1x16xf32> | ||||||||||||||||||||||||||||||||
/// %transpose = vector.transpose %shape_cast, [0, 1, 3, 4, 2] | ||||||||||||||||||||||||||||||||
/// : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32> | ||||||||||||||||||||||||||||||||
/// %write = vector.transfer_write %transpose, | ||||||||||||||||||||||||||||||||
/// %empty[%c0_0, %c0_0, %c0_0, %c0_0, %c0_0] | ||||||||||||||||||||||||||||||||
/// {in_bounds = [true, true, true, true, true]} | ||||||||||||||||||||||||||||||||
/// : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32> | ||||||||||||||||||||||||||||||||
static LogicalResult | ||||||||||||||||||||||||||||||||
vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp, | ||||||||||||||||||||||||||||||||
ArrayRef<int64_t> inputVectorSizes, | ||||||||||||||||||||||||||||||||
SmallVectorImpl<Value> &newResults) { | ||||||||||||||||||||||||||||||||
OpBuilder::InsertionGuard g(rewriter); | ||||||||||||||||||||||||||||||||
rewriter.setInsertionPoint(packOp); | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
Location loc = packOp.getLoc(); | ||||||||||||||||||||||||||||||||
auto padValue = packOp.getPaddingValue(); | ||||||||||||||||||||||||||||||||
if (!padValue) { | ||||||||||||||||||||||||||||||||
padValue = rewriter.create<arith::ConstantOp>( | ||||||||||||||||||||||||||||||||
loc, rewriter.getZeroAttr(packOp.getSourceType().getElementType())); | ||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||
ReifiedRankedShapedTypeDims reifiedReturnShapes; | ||||||||||||||||||||||||||||||||
LogicalResult status = | ||||||||||||||||||||||||||||||||
cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation()) | ||||||||||||||||||||||||||||||||
.reifyResultShapes(rewriter, reifiedReturnShapes); | ||||||||||||||||||||||||||||||||
(void)status; // prevent unused variable warning on non-assert builds. | ||||||||||||||||||||||||||||||||
assert(succeeded(status) && "failed to reify result shapes"); | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
// Create masked TransferReadOp. | ||||||||||||||||||||||||||||||||
SmallVector<int64_t> inputShape(inputVectorSizes); | ||||||||||||||||||||||||||||||||
auto innerTiles = packOp.getStaticInnerTiles(); | ||||||||||||||||||||||||||||||||
auto innerDimsPos = packOp.getInnerDimsPos(); | ||||||||||||||||||||||||||||||||
auto outerDimsPerm = packOp.getOuterDimsPerm(); | ||||||||||||||||||||||||||||||||
if (!outerDimsPerm.empty()) | ||||||||||||||||||||||||||||||||
applyPermutationToVector(inputShape, | ||||||||||||||||||||||||||||||||
invertPermutationVector(outerDimsPerm)); | ||||||||||||||||||||||||||||||||
for (auto [idx, size] : enumerate(innerTiles)) | ||||||||||||||||||||||||||||||||
inputShape[innerDimsPos[idx]] *= size; | ||||||||||||||||||||||||||||||||
auto maskedRead = createReadOrMaskedRead(rewriter, loc, packOp.getSource(), | ||||||||||||||||||||||||||||||||
inputShape, padValue); | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
// Create ShapeCastOp. | ||||||||||||||||||||||||||||||||
SmallVector<int64_t> destShape(inputVectorSizes); | ||||||||||||||||||||||||||||||||
destShape.append(innerTiles.begin(), innerTiles.end()); | ||||||||||||||||||||||||||||||||
auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape), | ||||||||||||||||||||||||||||||||
packOp.getDestType().getElementType()); | ||||||||||||||||||||||||||||||||
auto shapeCastOp = | ||||||||||||||||||||||||||||||||
rewriter.create<vector::ShapeCastOp>(loc, tiledPackType, maskedRead); | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
// Create TransposeOp. | ||||||||||||||||||||||||||||||||
auto destPermutation = | ||||||||||||||||||||||||||||||||
invertPermutationVector(tensor::getPackInverseDestPermutation(packOp)); | ||||||||||||||||||||||||||||||||
auto transposeOp = rewriter.create<vector::TransposeOp>( | ||||||||||||||||||||||||||||||||
loc, shapeCastOp.getResult(), destPermutation); | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
// Create TransferWriteOp. | ||||||||||||||||||||||||||||||||
Operation *write = | ||||||||||||||||||||||||||||||||
createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(), | ||||||||||||||||||||||||||||||||
reifiedReturnShapes[0], inputVectorSizes); | ||||||||||||||||||||||||||||||||
newResults.push_back(write->getResult(0)); | ||||||||||||||||||||||||||||||||
return success(); | ||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
/// Vectorize a `padOp` with (1) static result type, (2) constant padding value | ||||||||||||||||||||||||||||||||
/// and (3) all-zero lowPad to | ||||||||||||||||||||||||||||||||
/// `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`. | ||||||||||||||||||||||||||||||||
|
@@ -1402,9 +1568,6 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp, | |||||||||||||||||||||||||||||||
SmallVectorImpl<Value> &newResults) { | ||||||||||||||||||||||||||||||||
auto padValue = padOp.getConstantPaddingValue(); | ||||||||||||||||||||||||||||||||
Location loc = padOp.getLoc(); | ||||||||||||||||||||||||||||||||
int64_t rank = inputVectorSizes.size(); | ||||||||||||||||||||||||||||||||
auto maskType = VectorType::get(inputVectorSizes, rewriter.getI1Type()); | ||||||||||||||||||||||||||||||||
auto vectorType = VectorType::get(inputVectorSizes, padValue.getType()); | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
// transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value)) | ||||||||||||||||||||||||||||||||
OpBuilder::InsertionGuard g(rewriter); | ||||||||||||||||||||||||||||||||
|
@@ -1416,36 +1579,10 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp, | |||||||||||||||||||||||||||||||
.reifyResultShapes(rewriter, reifiedReturnShapes); | ||||||||||||||||||||||||||||||||
(void)status; // prevent unused variable warning on non-assert builds | ||||||||||||||||||||||||||||||||
assert(succeeded(status) && "failed to reify result shapes"); | ||||||||||||||||||||||||||||||||
auto emptyOp = rewriter.create<tensor::EmptyOp>(loc, reifiedReturnShapes[0], | ||||||||||||||||||||||||||||||||
padValue.getType()); | ||||||||||||||||||||||||||||||||
SmallVector<OpFoldResult> mixedSourceDims = | ||||||||||||||||||||||||||||||||
tensor::getMixedSizes(rewriter, loc, padOp.getSource()); | ||||||||||||||||||||||||||||||||
Value mask = | ||||||||||||||||||||||||||||||||
rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims); | ||||||||||||||||||||||||||||||||
auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); | ||||||||||||||||||||||||||||||||
auto transferReadOp = rewriter.create<vector::TransferReadOp>( | ||||||||||||||||||||||||||||||||
loc, | ||||||||||||||||||||||||||||||||
/*vectorType=*/vectorType, | ||||||||||||||||||||||||||||||||
/*source=*/padOp.getSource(), | ||||||||||||||||||||||||||||||||
/*indices=*/SmallVector<Value>(rank, zero), | ||||||||||||||||||||||||||||||||
/*padding=*/padValue, | ||||||||||||||||||||||||||||||||
/*inBounds=*/SmallVector<bool>(rank, true)); | ||||||||||||||||||||||||||||||||
auto maskedOp = cast<vector::MaskOp>( | ||||||||||||||||||||||||||||||||
mlir::vector::maskOperation(rewriter, transferReadOp, mask)); | ||||||||||||||||||||||||||||||||
Operation *write = rewriter.create<vector::TransferWriteOp>( | ||||||||||||||||||||||||||||||||
loc, | ||||||||||||||||||||||||||||||||
/*vector=*/maskedOp->getResult(0), | ||||||||||||||||||||||||||||||||
/*source=*/emptyOp, | ||||||||||||||||||||||||||||||||
/*indices=*/SmallVector<Value>(rank, zero), | ||||||||||||||||||||||||||||||||
/*inBounds=*/SmallVector<bool>(rank, true)); | ||||||||||||||||||||||||||||||||
bool needMaskForWrite = llvm::any_of( | ||||||||||||||||||||||||||||||||
llvm::zip_equal(inputVectorSizes, padOp.getResultType().getShape()), | ||||||||||||||||||||||||||||||||
[](auto it) { return std::get<0>(it) != std::get<1>(it); }); | ||||||||||||||||||||||||||||||||
if (needMaskForWrite) { | ||||||||||||||||||||||||||||||||
Value maskForWrite = rewriter.create<vector::CreateMaskOp>( | ||||||||||||||||||||||||||||||||
loc, maskType, reifiedReturnShapes[0]); | ||||||||||||||||||||||||||||||||
write = mlir::vector::maskOperation(rewriter, write, maskForWrite); | ||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||
auto maskedRead = createReadOrMaskedRead(rewriter, loc, padOp.getSource(), | ||||||||||||||||||||||||||||||||
inputVectorSizes, padValue); | ||||||||||||||||||||||||||||||||
Operation *write = createWriteOrMaskedWrite( | ||||||||||||||||||||||||||||||||
rewriter, loc, maskedRead, reifiedReturnShapes[0], inputVectorSizes); | ||||||||||||||||||||||||||||||||
newResults.push_back(write->getResult(0)); | ||||||||||||||||||||||||||||||||
return success(); | ||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||
|
@@ -1585,6 +1722,32 @@ vectorizeLinalgOpPrecondition(LinalgOp linalgOp, | |||||||||||||||||||||||||||||||
return success(); | ||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
/// TODO: Use a matcher to check for a constant padding value. | ||||||||||||||||||||||||||||||||
static LogicalResult | ||||||||||||||||||||||||||||||||
vectorizePackOpPrecondition(tensor::PackOp packOp, | ||||||||||||||||||||||||||||||||
ArrayRef<int64_t> inputVectorSizes) { | ||||||||||||||||||||||||||||||||
auto padValue = packOp.getPaddingValue(); | ||||||||||||||||||||||||||||||||
if (padValue && !padValue.getDefiningOp<arith::ConstantOp>()) { | ||||||||||||||||||||||||||||||||
LDBG("pad value is not constant: " << packOp << "\n"); | ||||||||||||||||||||||||||||||||
return failure(); | ||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
Comment on lines
+1725
to
+1734
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I realized there is no matcher for constant float values like there is for constant int values. I didn't know exactly how this should be done. llvm-project/mlir/lib/Dialect/Utils/StaticValueUtils.cpp Lines 108 to 122 in 07bf1dd
For int, getConstantIntValue passes an uninitialized APSInt. For the analogous float matcher, I'm not sure the best way to implement a similar function, since there are several possible floating point semantics. I suppose there could be a util function that simply checks if the value is constant and does not return the value. I think then an arbitrary float semantic could be used. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. good catch, I think the current implementation is okay, which checks if they are from arith.constant. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hey Max, I just found this might address the concern. // Return the constant attribute, or null if the Operation isn't a constant.
Attribute getConstantAttr(Operation *constantOp) {
Attribute constant;
matchPattern(value.getDefiningOp(), m_Constant());
return constant;
} |
||||||||||||||||||||||||||||||||
ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape(); | ||||||||||||||||||||||||||||||||
if (failed(isValidMaskedInputVector( | ||||||||||||||||||||||||||||||||
resultTensorShape.take_front(packOp.getSourceRank()), | ||||||||||||||||||||||||||||||||
inputVectorSizes))) | ||||||||||||||||||||||||||||||||
return failure(); | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) { | ||||||||||||||||||||||||||||||||
return !getConstantIntValue(v).has_value(); | ||||||||||||||||||||||||||||||||
})) { | ||||||||||||||||||||||||||||||||
LDBG("inner_tiles must be constant: " << packOp << "\n"); | ||||||||||||||||||||||||||||||||
return failure(); | ||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
return success(); | ||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
static LogicalResult | ||||||||||||||||||||||||||||||||
vectorizePadOpPrecondition(tensor::PadOp padOp, | ||||||||||||||||||||||||||||||||
ArrayRef<int64_t> inputVectorSizes) { | ||||||||||||||||||||||||||||||||
|
@@ -1644,6 +1807,9 @@ LogicalResult mlir::linalg::vectorizeOpPrecondition( | |||||||||||||||||||||||||||||||
.Case<tensor::PadOp>([&](auto padOp) { | ||||||||||||||||||||||||||||||||
return vectorizePadOpPrecondition(padOp, inputVectorSizes); | ||||||||||||||||||||||||||||||||
}) | ||||||||||||||||||||||||||||||||
.Case<tensor::PackOp>([&](auto packOp) { | ||||||||||||||||||||||||||||||||
return vectorizePackOpPrecondition(packOp, inputVectorSizes); | ||||||||||||||||||||||||||||||||
}) | ||||||||||||||||||||||||||||||||
.Default([](auto) { return failure(); }); | ||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
|
@@ -1732,6 +1898,10 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op, | |||||||||||||||||||||||||||||||
return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes, | ||||||||||||||||||||||||||||||||
results); | ||||||||||||||||||||||||||||||||
}) | ||||||||||||||||||||||||||||||||
.Case<tensor::PackOp>([&](auto packOp) { | ||||||||||||||||||||||||||||||||
return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes, | ||||||||||||||||||||||||||||||||
results); | ||||||||||||||||||||||||||||||||
}) | ||||||||||||||||||||||||||||||||
.Default([](auto) { return failure(); }); | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
if (failed(vectorizeResult)) { | ||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, I'll fix these extra includes in the next round of comments