[mlir][vector] Add more patterns to Vector Linearize transformation #136193
Conversation
@llvm/pr-subscribers-mlir-vector
Author: Nishant Patel (nbpatel)
Changes: This PR adds linearization patterns for vector.load, vector.store, vector.create_mask, vector.splat, vector.insert_strided_slice & RegionBranchOps. These patterns are needed because SPIR-V only supports 1D vectors. Patch is 40.54 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/136193.diff 3 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index a009aa03aaf64..6de5d0c5a101e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
@@ -27,6 +28,10 @@
using namespace mlir;
static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
+ // For BW-0, all operations are legal
+ if (targetBitWidth == 0) {
+ return false;
+ }
auto resultTypes = op->getResultTypes();
for (auto resType : resultTypes) {
VectorType vecType = dyn_cast<VectorType>(resType);
@@ -273,6 +278,77 @@ struct LinearizeVectorExtractStridedSlice final
unsigned targetVectorBitWidth;
};
+/// This pattern linearizes the InsertStridedSliceOp by extracting rows from the
+/// source vector using ExtractStridedSliceOp and inserting them into the
+/// destination vector using InsertStridedSliceOp.
+/// Following,
+/// vector.insert_strided_slice %s, %d {offsets=[0, 0]}: vector<2x4xf32> into vector<4x4xf32>
+/// is converted to :
+/// %0 = vector.extract_strided_slice %s {offsets=[0], sizes=[4], strides=[1]} : vector<4xf32> from vector<8xf32>
+/// %1 = vector.insert_strided_slice %0, %d {offsets=[0], strides=[1]} : vector<4xf32> into vector<16xf32>
+/// %2 = vector.extract_strided_slice %s {offsets=[4], sizes=[4], strides=[1]} : vector<4xf32> from vector<8xf32>
+/// %3 = vector.insert_strided_slice %2, %1 {offsets=[4], strides=[1]} : vector<4xf32> into vector<16xf32>
+struct LinearizeVectorInsertStridedSlice final
+ : public OpConversionPattern<vector::InsertStridedSliceOp> {
+ using OpConversionPattern<
+ vector::InsertStridedSliceOp>::OpConversionPattern;
+ LinearizeVectorInsertStridedSlice(
+ const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+ PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit),
+ targetVectorBitWidth(targetVectBitWidth) {}
+
+ LogicalResult
+ matchAndRewrite(vector::InsertStridedSliceOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+ auto srcTy = op.getSourceVectorType();
+ auto dstTy = op.getDestVectorType();
+
+ if (op.hasNonUnitStrides()) {
+ return rewriter.notifyMatchFailure(
+ op, "InsertStridedSliceOp linearization only supports unit strides.");
+ }
+
+ if (srcTy.getRank() != 2) {
+ return rewriter.notifyMatchFailure(
+ op, "InsertStridedSliceOp linearization only supports 2D source.");
+ }
+
+ if (!srcTy.hasStaticShape() || !dstTy.hasStaticShape()) {
+ return rewriter.notifyMatchFailure(
+ op, "InsertStridedSliceOp linerization only supports static shapes.");
+ }
+
+ auto dstShape = dstTy.getShape();
+ auto dstStrides = dstShape.drop_front().vec();
+ dstStrides.push_back(1);
+ int64_t linearizedOffset = 0;
+ for (auto [off, stride] : llvm::zip_equal(op.getOffsets(), dstStrides)) {
+ linearizedOffset += getConstantIntValue(off).value() * stride;
+ }
+
+ // Extract a row from the source and insert it into the destination.
+ auto srcShape = srcTy.getShape();
+ Value dstValue = adaptor.getDest();
+ for (auto i = 0; i < srcShape[0]; i++) {
+ auto srcOffset = i * srcShape[1];
+ auto value = rewriter.create<vector::ExtractStridedSliceOp>(
+ loc, adaptor.getValueToStore(), srcOffset, srcShape[1], 1);
+
+ auto dstOffset = linearizedOffset + i * dstShape.back();
+ dstValue = rewriter.create<vector::InsertStridedSliceOp>(
+ loc, value, dstValue, dstOffset, 1);
+ }
+
+ rewriter.replaceOp(op, dstValue);
+ return success();
+ }
+ private:
+ unsigned targetVectorBitWidth;
+};
+
/// This pattern converts the ShuffleOp that works on nD (n > 1)
/// vectors to a ShuffleOp that works on linearized vectors.
/// Following,
@@ -369,6 +445,11 @@ struct LinearizeVectorExtract final
LogicalResult
matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ // Skip if result is not a vector type
+ if (!isa<VectorType>(extractOp.getType()))
+ return rewriter.notifyMatchFailure(extractOp,
+ "scalar extract is not supported.");
+
Type dstTy = getTypeConverter()->convertType(extractOp.getType());
if (!dstTy)
return rewriter.notifyMatchFailure(extractOp,
@@ -531,12 +612,312 @@ struct LinearizeVectorBitCast final
unsigned targetVectorBitWidth;
};
+/// This pattern converts the LoadOp to a series of LoadOp & InsertOp
+/// that works on a linearized vector.
+/// Following,
+/// vector.load %base[%indices] : vector<4x4xf32>
+/// is converted to :
+/// %result = arith.constant dense<0.0> : vector<4x4xf32>
+/// %slice_0 = vector.load %base[%indices] : vector<4xf32>
+/// %result_0 = vector.insert %slice_0, %result[0] : vector<4xf32> into vector<4x4xf32>
+/// %slice_1 = vector.load %base[%indices + 1] : vector<4xf32>
+/// %result_1 = vector.insert %slice_1, %result_0[1] : vector<4xf32> into vector<4x4xf32>
+/// ...
+/// This unrolls the 2D vector load into multiple 1D vector loads and inserts
+/// them into the result vector. The pattern currently supports only 2D vectors.
+struct LinearizeVectorLoad final
+ : public OpConversionPattern<vector::LoadOp> {
+ using OpConversionPattern<vector::LoadOp>::OpConversionPattern;
+
+ LinearizeVectorLoad(
+ const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+ PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit),
+ targetVectorBitWidth(targetVectBitWidth) {}
+
+ LogicalResult
+ matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = loadOp->getLoc();
+ auto vecType = loadOp.getVectorType();
+ auto shape = vecType.getShape();
+
+ if (shape.size() != 2) {
+ return rewriter.notifyMatchFailure(loc, "Can only linearize 2D vectors.");
+ }
+ auto unrollCount = shape[0];
+ auto vecSize = shape[1];
+ auto newVecType =
+ VectorType::get({vecSize}, vecType.getElementType());
+
+ llvm::SmallVector<Value, 4> indices = adaptor.getIndices();
+ Value xBaseIndex = indices[0];
+
+ // Construct the 2D vector.
+ Value resultVec = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getZeroAttr(vecType));
+ // Emit unrolled loads for each 1D vector slice.
+ for (auto i = 0; i < unrollCount; i++) {
+ Value xIndex = xBaseIndex;
+ if (i) {
+ auto increment = rewriter.create<arith::ConstantIndexOp>(loc, i);
+ xIndex =
+ rewriter.create<arith::AddIOp>(loc, xBaseIndex, increment);
+ }
+ indices[0] = xIndex;
+ auto vec = rewriter.create<vector::LoadOp>(
+ loc, newVecType, adaptor.getBase(), indices);
+ resultVec =
+ rewriter.create<vector::InsertOp>(loc, vec, resultVec, i);
+ }
+
+ rewriter.replaceOp(loadOp, resultVec);
+ return success();
+ }
+ private:
+ unsigned targetVectorBitWidth;
+};
+
+/// This pattern converts the StoreOp to a series of StoreOp & ExtractOp
+/// that works on a linearized vector.
+/// Following,
+/// vector.store %source, %base[%indices] : vector<4x4xf32>
+/// is converted to :
+/// %slice_0 = vector.extract %source[0] : vector<4xf32>
+/// vector.store %slice_0, %base[%indices] : vector<4xf32>
+/// %slice_1 = vector.extract %source[1] : vector<4xf32>
+/// vector.store %slice_1, %base[%indices + 1] : vector<4xf32>
+/// ...
+/// This unrolls the 2D vector store into multiple 1D vector stores by extracting
+/// slices from the source vector and storing them into the destination.
+/// The pattern currently supports only 2D vectors.
+struct LinearizeVectorStore final
+ : public OpConversionPattern<vector::StoreOp> {
+ using OpConversionPattern<vector::StoreOp>::OpConversionPattern;
+
+ LinearizeVectorStore(
+ const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+ PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit),
+ targetVectorBitWidth(targetVectBitWidth) {}
+
+ LogicalResult
+ matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = storeOp->getLoc();
+ auto vecType = storeOp.getVectorType();
+ auto shape = vecType.getShape();
+
+ if (shape.size() != 2) {
+ return rewriter.notifyMatchFailure(loc, "Can only linearize 2D vectors.");
+ }
+
+ auto unrollCount = shape[0];
+ llvm::SmallVector<Value, 4> indices = adaptor.getIndices();
+ Value xBaseIndex = indices[0];
+
+ auto vec = rewriter.create<vector::ShapeCastOp>(
+ loc, vecType, adaptor.getValueToStore());
+
+ for (auto i = 0; i < unrollCount; i++) {
+ auto vecSlice = rewriter.create<vector::ExtractOp>(loc, vec, i);
+ Value xIndex = xBaseIndex;
+ if (i) {
+ auto increment = rewriter.create<arith::ConstantIndexOp>(loc, i);
+ xIndex =
+ rewriter.create<arith::AddIOp>(loc, xBaseIndex, increment);
+ }
+ indices[0] = xIndex;
+ rewriter.create<vector::StoreOp>(loc, vecSlice, adaptor.getBase(),
+ indices);
+ }
+ rewriter.eraseOp(storeOp);
+ return success();
+ }
+ private:
+ unsigned targetVectorBitWidth;
+};
+
+/// This pattern converts the SplatOp to work on a linearized vector.
+/// Following,
+/// vector.splat %value : vector<4x4xf32>
+/// is converted to:
+/// %out_1d = vector.splat %value : vector<16xf32>
+/// %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32>
+/// It ensures that the operation is compatible with the target vector
+/// bit width and replaces the original operation with a new SplatOp
+/// that operates on the converted type.
+struct LinearizeVectorSplat final
+ : public OpConversionPattern<vector::SplatOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LinearizeVectorSplat(
+ const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+ PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit),
+ targetVectorBitWidth(targetVectBitWidth) {}
+
+ LogicalResult
+ matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto dstTy = getTypeConverter()->convertType(splatOp.getType());
+ if (!dstTy)
+ return rewriter.notifyMatchFailure(splatOp, "cannot convert type.");
+ rewriter.replaceOpWithNewOp<vector::SplatOp>(
+ splatOp, adaptor.getInput(), dstTy);
+ return success();
+ }
+ private:
+ unsigned targetVectorBitWidth;
+};
+
+/// This pattern converts the CreateMaskOp to work on a
+/// linearized vector. It ensures that the operation is compatible with the
+/// target vector bit width and replaces the original operation with a new
+/// CreateMaskOp that operates on the converted type. The pattern currently
+/// supports only 2D masks with a unit outer dimension.
+/// Following,
+/// vector.create_mask %dims : vector<1x4xi1>
+/// is converted to:
+/// %out_1d = vector.create_mask %dims : vector<4xi1>
+/// %out_nd = vector.shape_cast %out_1d : vector<4xi1> to vector<1x4xi1>
+struct LinearizeVectorCreateMask final
+ : OpConversionPattern<vector::CreateMaskOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LinearizeVectorCreateMask(
+ const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+ PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit),
+ targetVectorBitWidth(targetVectBitWidth) {}
+
+ LogicalResult
+ matchAndRewrite(vector::CreateMaskOp createMaskOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto srcTy = createMaskOp.getType();
+ auto srcShape = srcTy.getShape();
+ if (srcShape.size() != 2)
+ return rewriter.notifyMatchFailure(createMaskOp,
+ "only 2D mask is supported.");
+
+ if (srcShape[0] != 1)
+ return rewriter.notifyMatchFailure(
+ createMaskOp, "only unit outer dimension is supported.");
+
+ auto dstTy = getTypeConverter()->convertType(srcTy);
+ if (!dstTy)
+ return rewriter.notifyMatchFailure(createMaskOp, "cannot convert type.");
+
+ rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
+ createMaskOp, dstTy, adaptor.getOperands().back());
+ return success();
+ }
+ private:
+ unsigned targetVectorBitWidth;
+};
+
+/// This pattern converts operations implementing the RegionBranchOpInterface
+/// to ensure compatibility with linearized vector types. It updates the
+/// operands, result types, and region types (block arguments and yields) to
+/// match the converted types. Additionally, it processes yields within each
+/// region to ensure that the types of yielded values are compatible with the
+/// target vector bit width. If the result types of the operation are updated,
+/// shape cast operations are inserted to maintain compatibility with the
+/// original types. This pattern ensures that operations with regions are
+/// properly linearized and remain valid after type conversion.
+struct LinearizeRegionBranchOp final
+ : public OpInterfaceConversionPattern<RegionBranchOpInterface> {
+ using OpInterfaceConversionPattern<
+ RegionBranchOpInterface>::OpInterfaceConversionPattern;
+
+ LinearizeRegionBranchOp(
+ const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+ PatternBenefit benefit = 1)
+ : OpInterfaceConversionPattern(typeConverter, context, benefit),
+ targetVectorBitWidth(targetVectBitWidth) {}
+
+ LogicalResult
+ matchAndRewrite(RegionBranchOpInterface op,
+ ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+ auto converter = getTypeConverter();
+
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.startOpModification(op);
+
+ llvm::SmallVector<Type> convertedTypes;
+ for (Type ty : op->getResultTypes()) {
+ convertedTypes.push_back(converter->convertType(ty));
+ }
+
+ if (convertedTypes == op->getResultTypes() &&
+ op->getOperands() == operands) {
+ return failure();
+ }
+
+ op->setOperands(operands);
+
+ // Convert region types (block arguments and yields)
+ for (Region ®ion : op->getRegions()) {
+ if (failed(rewriter.convertRegionTypes(®ion, *converter))) {
+ return failure();
+ }
+
+ // Process yields within each region
+ for (Block &block : region) {
+ if (auto *terminator = block.getTerminator()) {
+ for (OpOperand &yieldOperand : terminator->getOpOperands()) {
+ Value value = yieldOperand.get();
+ Type type = value.getType();
+ if (!converter->isLegal(type)) {
+ Type newTy = converter->convertType(type);
+ rewriter.setInsertionPoint(terminator);
+ Value newValue =
+ rewriter.create<vector::ShapeCastOp>(loc, newTy, value);
+ yieldOperand.set(newValue);
+ }
+ }
+ }
+ }
+ }
+
+ // Update result types
+ rewriter.setInsertionPointAfter(op);
+ llvm::SmallVector<Value> newResults;
+ for (Value result : op->getResults()) {
+ Type oldTy = result.getType();
+ if (!converter->isLegal(oldTy)) {
+ Type newTy = converter->convertType(oldTy);
+ result.setType(newTy);
+ Operation *castOp =
+ rewriter.create<vector::ShapeCastOp>(loc, oldTy, result);
+ result.replaceAllUsesExcept(castOp->getResult(0), castOp);
+ newResults.push_back(castOp->getResult(0));
+ } else {
+ newResults.push_back(result);
+ }
+ }
+
+ rewriter.finalizeOpModification(op);
+ return success();
+ }
+ private:
+ unsigned targetVectorBitWidth;
+};
+
} // namespace
void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, unsigned targetBitWidth) {
+ typeConverter.addConversion([](Type type) -> Type { return type; });
typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
if (!isLinearizableVector(type))
return type;
@@ -555,9 +936,12 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
};
typeConverter.addSourceMaterialization(materializeCast);
typeConverter.addTargetMaterialization(materializeCast);
+ target.addLegalOp<mlir::vector::ShapeCastOp>();
target.markUnknownOpDynamicallyLegal(
[=](Operation *op) -> std::optional<bool> {
- if ((isa<vector::BitCastOp>(op) ||
+ if ((isa<vector::BitCastOp, vector::LoadOp,
+ vector::StoreOp, vector::CreateMaskOp,
+ RegionBranchOpInterface, vector::SplatOp>(op) ||
op->hasTrait<OpTrait::ConstantLike>() ||
op->hasTrait<OpTrait::Vectorizable>())) {
return (isLessThanTargetBitWidth(op, targetBitWidth)
@@ -568,7 +952,10 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
});
patterns.add<LinearizeConstantLike, LinearizeVectorizable,
- LinearizeVectorBitCast>(typeConverter, patterns.getContext(),
+ LinearizeVectorBitCast, LinearizeVectorLoad,
+ LinearizeVectorStore, LinearizeVectorSplat,
+ LinearizeVectorCreateMask, LinearizeRegionBranchOp
+ >(typeConverter, patterns.getContext(),
targetBitWidth);
}
@@ -583,7 +970,21 @@ void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
.getRank() == 1)
: true;
});
+
+ target.addDynamicallyLegalOp<vector::InsertStridedSliceOp>(
+ [=](vector::InsertStridedSliceOp op) -> bool {
+ if (isLessThanTargetBitWidth(op, targetBitWidth)) {
+ auto srcTy = op.getSourceVectorType();
+ auto dstTy = op.getDestVectorType();
+ if (!op.hasNonUnitStrides() && srcTy.getRank() == 2 &&
+ srcTy.hasStaticShape() && dstTy.hasStaticShape())
+ return false;
+ }
+ return true;
+ });
+
patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
- LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
+ LinearizeVectorInsert, LinearizeVectorExtractStridedSlice,
+ LinearizeVectorInsertStridedSlice>(
typeConverter, patterns.getContext(), targetBitWidth);
}
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 9052c6440e6ac..e47e7c4a84d68 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -399,3 +399,338 @@ func.func @test_vector_bitcast(%arg0: vector<[4]x2xf32>) -> vector<[4]x4xf16> {
%1 = vector.bitcast %arg0 : vector<[4]x2xf32> to vector<[4]x4xf16>
return %1 : vector<[4]x4xf16>
}
+
+// -----
+// ALL-LABEL: test_vector_load
+// ALL-SAME: (%[[ARG_0:.*]]: memref<4x4xf16>)
+func.func @test_vector_load(%arg0: memref<4x4xf16>) -> vector<4x4xf16> {
+ // DEFAULT: %[[C1:.*]] = arith.constant 1 : index
+ // BW-128: %[[C1:.*]] = arith.constant 1 : index
+ // DEFAULT: %[[C2:.*]] = arith.constant 2 : index
+ // BW-128: %[[C2:.*]] = arith.constant 2 : index
+...
[truncated]
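For orientation, here is a minimal before/after sketch of the kind of rewrite these patterns perform, based on the splat case documented in the patch above (values and names are illustrative):

%0 = vector.splat %value : vector<4x4xf32>

becomes a 1D splat plus a shape cast back to the original type:

%0 = vector.splat %value : vector<16xf32>
%1 = vector.shape_cast %0 : vector<16xf32> to vector<4x4xf32>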
✅ With the latest revision this PR passed the C/C++ code formatter.
@Hardcode84 @charithaintc @chencha3 Please take a look as well
Just a few drive-by comments. I'm no expert on this, so please ignore my suggestions where not appropriate.
@@ -27,6 +27,10 @@
using namespace mlir;

static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
  // For BW-0, all operations are legal
  if (targetBitWidth == 0)
targetBitWidth = std::numeric_limits<unsigned>::max() is used in places, please consolidate.
Apologies I wasn't clear in this comment. I meant to consolidate the use of targetBitWidth = 0 and targetBitWidth = 'max'. Is this a workaround for adding 1 to std::numeric_limits<unsigned>::max()?
I would like to commit #136581, which would mean this logic doesn't live here anymore and this comment wouldn't be relevant; could you please take a look at that?
I briefly looked at it. I'm ok with that change, but can we commit this first if possible?
I have reached a point where I need to land this to make progress with other development; I hope that's ok, @nbpatel.
It should be easy to absorb the changes (just remove the code related to bitwidth).
yes go ahead and merge it
/// vector.extract_strided_slice %s {offsets=[4], sizes=[4], strides=[1]} :
/// vector<4xf32> from vector<8xf32> %3 = vector.insert_strided_slice %2, %1
/// {offsets=[4], strides=[1]} : vector<4xf32> into vector<16xf32>
struct LinearizeVectorInsertStridedSlice final
Perhaps it's possible to reuse
llvm-project/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
Line 84 in e1bb7f6
class ConvertSameRankInsertStridedSliceIntoShuffle
I think since this is not a pass and just a bunch of patterns, users can decide how they want to lower from this point.
I would like it if this were a collection of patterns that the user could opt in or out of one-by-one, but currently only two APIs are exposed to the user (populateVectorLinearizeShuffleLikeOpsPatterns and populateVectorLinearizeTypeConversionsAndLegality).
I guess the users can run populateVectorLinearizeTypeConversionsAndLegality and populateVectorInsertExtractStridedSliceTransforms to convert to shuffles.
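A minimal sketch of what that combination could look like in a pass body (the pass boilerplate and the 128-bit threshold are hypothetical; the two populate functions are the ones named above):

void runOnOperation() {
  MLIRContext *ctx = &getContext();

  // Phase 1: linearize n-D vector ops to 1-D via the type conversion.
  TypeConverter typeConverter;
  ConversionTarget target(*ctx);
  RewritePatternSet linearizePatterns(ctx);
  vector::populateVectorLinearizeTypeConversionsAndLegality(
      typeConverter, linearizePatterns, target, /*targetBitWidth=*/128);
  if (failed(applyPartialConversion(getOperation(), target,
                                    std::move(linearizePatterns))))
    return signalPassFailure();

  // Phase 2: lower the remaining insert/extract_strided_slice ops to
  // shuffles using the pre-existing patterns.
  RewritePatternSet shufflePatterns(ctx);
  vector::populateVectorInsertExtractStridedSliceTransforms(shufflePatterns);
  if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                          std::move(shufflePatterns))))
    return signalPassFailure();
}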
// Skip if result is not a vector type
if (!isa<VectorType>(extractOp.getType()))
  return rewriter.notifyMatchFailure(extractOp,
                                     "scalar extract is not supported.");
"scalar extract is not supported."); | |
"scalar extract is not supported, because ..."); |
might be helpful!
@@ -531,12 +618,239 @@ struct LinearizeVectorBitCast final
unsigned targetVectorBitWidth;
};

/// This pattern converts the LoadOp to a series of LoadOp & InsertOp
formatting.
/// Following,
/// vector.load %base[%indices] : vector<4x4xf32>
/// is converted to :
/// %result = arith.constant dense<0.0> : vector<4x4xf32>
I wonder if it would be beneficial to flatten/linearize out all contiguous dimensions first. I.e., if the load is actually unstrided, like
%result = vector.load %base[%i, %j] : memref<100x100xf32>, vector<8x100xf32>
and this were flattened to
%result = vector.load %flat_base[%i] : memref<10000xf32>, vector<800xf32>
the generated IR wouldn't be unrolled.
It seems to me like this is more unrolling than linearizing?
llvm-project/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
Line 305 in 52a5332
void populateVectorUnrollPatterns(RewritePatternSet &patterns,
This is a more general solution for unstrided/strided memrefs, and we can always fuse the loads later on as an optimization.
/// vector.insert_strided_slice %s, %d {offsets=[0, 0]}: vector<2x4xf32> into
/// vector<4x4xf32>
/// is converted to :
/// %0 = vector.extract_strided_slice %s {offsets=[0], sizes=[4], strides=[1]}
Please check formatting. Maybe clang-format off and clang-format on will be helpful?
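For example, wrapping just the IR examples in the doc comment keeps clang-format from re-flowing them (a small illustration, not taken from the patch):

// clang-format off
/// vector.insert_strided_slice %s, %d {offsets=[0, 0]}
///     : vector<2x4xf32> into vector<4x4xf32>
// clang-format on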
Hi @nbpatel I think these patterns will all be useful for users of the vector dialect, but I have a few requests that I would like to make first:
I know they're large requests, my apologies for not suggesting them earlier in the process. I think your contributions will be significantly more useful if done this way.
Thanks!
This is adding 4 unrelated patterns - could you split this into independent PRs?
auto loc = loadOp->getLoc();
auto vecType = loadOp.getVectorType();
auto shape = vecType.getShape();
Could you spell out the types here and in other places? For reference, see LLVM's guidelines for using auto.
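A hypothetical before/after for this comment, using a line from the patch; per the guidelines, auto is fine when the type is obvious from the initializer (e.g. a cast) and should be spelled out otherwise:

auto vecType = loadOp.getVectorType();        // type not visible at the call site
VectorType vecType = loadOp.getVectorType();  // preferred: type spelled out
ArrayRef<int64_t> shape = vecType.getShape(); // likewise for the shape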
// -----
// ALL-LABEL: test_vector_load
// ALL-SAME: (%[[ARG_0:.*]]: memref<4x4xf16>)
func.func @test_vector_load(%arg0: memref<4x4xf16>) -> vector<4x4xf16> {
Please avoid using test in test function names - that's unnecessary noise. I appreciate that you are trying to follow the existing convention in this file, but IMO we should focus on encoding unique information; see our guidelines.
So, what makes this test unique?
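For instance, a hypothetical rename following that guideline would encode what is unique about the test rather than the test_ prefix:

func.func @linearize_2d_vector_load(%arg0: memref<4x4xf16>) -> vector<4x4xf16>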
// DEFAULT: %[[C1:.*]] = arith.constant 1 : index
// BW-128: %[[C1:.*]] = arith.constant 1 : index
// DEFAULT: %[[C2:.*]] = arith.constant 2 : index
// BW-128: %[[C2:.*]] = arith.constant 2 : index
// DEFAULT: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<16xf16>
// BW-128: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<16xf16>
// DEFAULT: %[[LOAD0:.*]] = vector.load %[[ARG_0]][%[[C1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
// BW-128: %[[LOAD0:.*]] = vector.load %[[ARG_0]][%[[C1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
// DEFAULT: %[[SHUFFLE0:.*]] = vector.shuffle %[[CST]], %[[LOAD0]] [16, 17, 18, 19, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<16xf16>, vector<4xf16>
// BW-128: %[[SHUFFLE0:.*]] = vector.shuffle %[[CST]], %[[LOAD0]] [16, 17, 18, 19, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<16xf16>, vector<4xf16>
// DEFAULT: %[[C1_0:.*]] = arith.constant 1 : index
// BW-128: %[[C1_0:.*]] = arith.constant 1 : index
// DEFAULT: %[[ADD0:.*]] = arith.addi %[[C1]], %[[C1_0]] : index
// BW-128: %[[ADD0:.*]] = arith.addi %[[C1]], %[[C1_0]] : index
// DEFAULT: %[[LOAD1:.*]] = vector.load %[[ARG_0]][%[[ADD0]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
// BW-128: %[[LOAD1:.*]] = vector.load %[[ARG_0]][%[[ADD0]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
// DEFAULT: %[[SHUFFLE1:.*]] = vector.shuffle %[[SHUFFLE0]], %[[LOAD1]] [0, 1, 2, 3, 16, 17, 18, 19, 8, 9, 10, 11, 12, 13, 14, 15] : vector<16xf16>, vector<4xf16>
// BW-128: %[[SHUFFLE1:.*]] = vector.shuffle %[[SHUFFLE0]], %[[LOAD1]] [0, 1, 2, 3, 16, 17, 18, 19, 8, 9, 10, 11, 12, 13, 14, 15] : vector<16xf16>, vector<4xf16>
// DEFAULT: %[[C2_1:.*]] = arith.constant 2 : index
// BW-128: %[[C2_1:.*]] = arith.constant 2 : index
// DEFAULT: %[[ADD1:.*]] = arith.addi %[[C1]], %[[C2_1]] : index
// BW-128: %[[ADD1:.*]] = arith.addi %[[C1]], %[[C2_1]] : index
// DEFAULT: %[[LOAD2:.*]] = vector.load %[[ARG_0]][%[[ADD1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
// BW-128: %[[LOAD2:.*]] = vector.load %[[ARG_0]][%[[ADD1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
// DEFAULT: %[[SHUFFLE2:.*]] = vector.shuffle %[[SHUFFLE1]], %[[LOAD2]] [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 12, 13, 14, 15] : vector<16xf16>, vector<4xf16>
// BW-128: %[[SHUFFLE2:.*]] = vector.shuffle %[[SHUFFLE1]], %[[LOAD2]] [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 12, 13, 14, 15] : vector<16xf16>, vector<4xf16>
// DEFAULT: %[[C3:.*]] = arith.constant 3 : index
// BW-128: %[[C3:.*]] = arith.constant 3 : index
// DEFAULT: %[[ADD2:.*]] = arith.addi %[[C1]], %[[C3]] : index
// BW-128: %[[ADD2:.*]] = arith.addi %[[C1]], %[[C3]] : index
// DEFAULT: %[[LOAD3:.*]] = vector.load %[[ARG_0]][%[[ADD2]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
// BW-128: %[[LOAD3:.*]] = vector.load %[[ARG_0]][%[[ADD2]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
// DEFAULT: %[[SHUFFLE3:.*]] = vector.shuffle %[[SHUFFLE2]], %[[LOAD3]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 16, 17, 18, 19] : vector<16xf16>, vector<4xf16>
// BW-128: %[[SHUFFLE3:.*]] = vector.shuffle %[[SHUFFLE2]], %[[LOAD3]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 16, 17, 18, 19] : vector<16xf16>, vector<4xf16>
// DEFAULT: %[[CAST:.*]] = vector.shape_cast %[[SHUFFLE3]] : vector<16xf16> to vector<4x4xf16>
// BW-128: %[[CAST:.*]] = vector.shape_cast %[[SHUFFLE3]] : vector<16xf16> to vector<4x4xf16>
// DEFAULT: return %[[CAST]] : vector<4x4xf16>
// BW-128: return %[[CAST]] : vector<4x4xf16>

// BW-0: %[[C1:.*]] = arith.constant 1 : index
// BW-0: %[[C2:.*]] = arith.constant 2 : index
// BW-0: %[[LOAD:.*]] = vector.load %[[ARG_0]][%[[C1]], %[[C2]]] : memref<4x4xf16>, vector<4x4xf16>
// BW-0: return %[[LOAD]] : vector<4x4xf16>
It is very hard to follow this. Could you follow the pre-existing convention and split the DEFAULT, BW-128, and BW-0 blocks (as opposed to interleaving)? Similar comment for other tests.
/// %1 = vector.insert_strided_slice %0, %d {offsets=[0], strides=[1]}
/// : vector<4xf32> into vector<16xf32>
Suggested change:
-/// %1 = vector.insert_strided_slice %0, %d {offsets=[0], strides=[1]}
-/// : vector<4xf32> into vector<16xf32>
+/// %1 = vector.insert_strided_slice %0, %d {offsets=[0], strides=[1]}
+///     : vector<4xf32> into vector<16xf32>
/// vector.insert_strided_slice %s, %d {offsets=[0, 0]}: vector<2x4xf32> into
/// vector<4x4xf32>
Suggested change:
-/// vector.insert_strided_slice %s, %d {offsets=[0, 0]}: vector<2x4xf32> into
-/// vector<4x4xf32>
+/// vector.insert_strided_slice %s, %d {offsets=[0, 0]}
+/// : vector<2x4xf32> into vector<4x4xf32>
if (srcTy.getRank() != 2)
  return rewriter.notifyMatchFailure(
      insertOp,
      "InsertStridedSliceOp linearization only supports 2D source.");
Any reason for supporting only 2D?
So load/store as one PR, and splat, create_mask, and insert_strided_slice each as an independent PR?
This PR is a breakdown [2 / 4] of the PR #136193. The PR adds linearization patterns for vector.splat.
…m#138214) This PR is a breakdown [3 / 4] of the PR llvm#136193. The PR adds linearization patterns for vector.create_mask.