[mlir][vector] Add more patterns to Vector Linearize transformation #136193


Closed · 7 commits (not merged)

Conversation

@nbpatel (Contributor) commented Apr 17, 2025

This PR adds linearization patterns for vector.load, vector.store, vector.create_mask, vector.splat, vector.insert_strided_slice, and ops implementing RegionBranchOpInterface. These patterns are needed because SPIR-V only supports 1D vectors.
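
For illustration, a minimal sketch of how these patterns are typically driven through the dialect conversion framework (the wrapper function and its name are hypothetical; only populateVectorLinearizeTypeConversionsAndLegality is the entry point touched by this PR):

#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Hypothetical driver: wires the linearization type conversions, legality
// rules, and rewrite patterns into a partial dialect conversion.
static LogicalResult linearizeVectors(Operation *root,
                                      unsigned targetBitWidth) {
  MLIRContext *ctx = root->getContext();
  TypeConverter typeConverter;
  RewritePatternSet patterns(ctx);
  ConversionTarget target(*ctx);

  // Registers vector<N x M x T> -> vector<N*M x T> type conversions, marks
  // ops dynamically legal/illegal, and adds the patterns from this PR.
  vector::populateVectorLinearizeTypeConversionsAndLegality(
      typeConverter, patterns, target, targetBitWidth);

  return applyPartialConversion(root, target, std::move(patterns));
}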

@llvmbot (Member) commented Apr 17, 2025

@llvm/pr-subscribers-mlir-vector

Author: Nishant Patel (nbpatel)

Changes

This PR adds linearization patterns for vector.load, vector.store, vector.create_mask, vector.splat, vector.insert_strided_slice, and ops implementing RegionBranchOpInterface. These patterns are needed because SPIR-V only supports 1D vectors.


Patch is 40.54 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/136193.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (+404-3)
  • (modified) mlir/test/Dialect/Vector/linearize.mlir (+335)
  • (modified) mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp (+2-1)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index a009aa03aaf64..6de5d0c5a101e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
@@ -27,6 +28,10 @@
 using namespace mlir;
 
 static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
+  // For BW-0, all operations are legal
+  if (targetBitWidth == 0) {
+    return false;
+  }
   auto resultTypes = op->getResultTypes();
   for (auto resType : resultTypes) {
     VectorType vecType = dyn_cast<VectorType>(resType);
@@ -273,6 +278,77 @@ struct LinearizeVectorExtractStridedSlice final
   unsigned targetVectorBitWidth;
 };
 
+/// This pattern linearizes the InsertStridedSliceOp by extracting rows from the
+/// source vector using ExtractStridedSliceOp and inserting them into the
+/// destination vector using InsertStridedSliceOp.
+/// Following,
+///   vector.insert_strided_slice %s, %d {offsets=[0, 0]}: vector<2x4xf32> into vector<4x4xf32>
+/// is converted to :
+///   %0 = vector.extract_strided_slice %s {offsets=[0], sizes=[4], strides=[1]} : vector<4xf32> from vector<8xf32>
+///   %1 = vector.insert_strided_slice %0, %d {offsets=[0], strides=[1]} : vector<4xf32> into vector<16xf32>
+///   %2 = vector.extract_strided_slice %s {offsets=[4], sizes=[4], strides=[1]} : vector<4xf32> from vector<8xf32>
+///   %3 = vector.insert_strided_slice %2, %1 {offsets=[4], strides=[1]} : vector<4xf32> into vector<16xf32>
+struct LinearizeVectorInsertStridedSlice final
+    : public OpConversionPattern<vector::InsertStridedSliceOp> {
+  using OpConversionPattern<
+      vector::InsertStridedSliceOp>::OpConversionPattern;
+      LinearizeVectorInsertStridedSlice(
+    const TypeConverter &typeConverter, MLIRContext *context,
+    unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+    PatternBenefit benefit = 1)
+    : OpConversionPattern(typeConverter, context, benefit),
+      targetVectorBitWidth(targetVectBitWidth) {}
+
+  LogicalResult
+  matchAndRewrite(vector::InsertStridedSliceOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto srcTy = op.getSourceVectorType();
+    auto dstTy = op.getDestVectorType();
+
+    if (op.hasNonUnitStrides()) {
+      return rewriter.notifyMatchFailure(
+          op, "InsertStridedSliceOp linearization only supports unit strides.");
+    }
+
+    if (srcTy.getRank() != 2) {
+      return rewriter.notifyMatchFailure(
+          op, "InsertStridedSliceOp linearization only supports 2D source.");
+    }
+
+    if (!srcTy.hasStaticShape() || !dstTy.hasStaticShape()) {
+      return rewriter.notifyMatchFailure(
+          op, "InsertStridedSliceOp linerization only supports static shapes.");
+    }
+
+    auto dstShape = dstTy.getShape();
+    auto dstStrides = dstShape.drop_front().vec();
+    dstStrides.push_back(1);
+    int64_t linearizedOffset = 0;
+    for (auto [off, stride] : llvm::zip_equal(op.getOffsets(), dstStrides)) {
+      linearizedOffset += getConstantIntValue(off).value() * stride;
+    }
+
+    // Extract each row from the source and insert it into the destination.
+    auto srcShape = srcTy.getShape();
+    Value dstValue = adaptor.getDest();
+    for (auto i = 0; i < srcShape[0]; i++) {
+      auto srcOffset = i * srcShape[1];
+      auto value = rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, adaptor.getValueToStore(), srcOffset, srcShape[1], 1);
+
+      auto dstOffset = linearizedOffset + i * dstShape.back();
+      dstValue = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, value, dstValue, dstOffset, 1);
+    }
+
+    rewriter.replaceOp(op, dstValue);
+    return success();
+  }
+  private:
+    unsigned targetVectorBitWidth;
+};
+
 /// This pattern converts the ShuffleOp that works on nD (n > 1)
 /// vectors to a ShuffleOp that works on linearized vectors.
 /// Following,
@@ -369,6 +445,11 @@ struct LinearizeVectorExtract final
   LogicalResult
   matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    // Skip if result is not a vector type
+    if (!isa<VectorType>(extractOp.getType()))
+      return rewriter.notifyMatchFailure(extractOp,
+                                         "scalar extract is not supported.");
+
     Type dstTy = getTypeConverter()->convertType(extractOp.getType());
     if (!dstTy)
       return rewriter.notifyMatchFailure(extractOp,
@@ -531,12 +612,312 @@ struct LinearizeVectorBitCast final
   unsigned targetVectorBitWidth;
 };
 
+/// This pattern converts the LoadOp to a series of LoadOp & InsertOp
+/// that works on a linearized vector.
+/// Following,
+///   vector.load %base[%indices] : vector<4x4xf32>
+/// is converted to :
+///   %result = arith.constant dense<0.0> : vector<4x4xf32>
+///   %slice_0 = vector.load %base[%indices] : vector<4xf32>
+///   %result_0 = vector.insert %slice_0, %result[0] : vector<4xf32> into vector<4x4xf32>
+///   %slice_1 = vector.load %base[%indices + 1] : vector<4xf32>
+///   %result_1 = vector.insert %slice_1, %result_0[1] : vector<4xf32> into vector<4x4xf32>
+///   ...
+/// This unrolls the 2D vector load into multiple 1D vector loads and inserts
+/// them into the result vector. The pattern currently supports only 2D vectors.
+struct LinearizeVectorLoad final
+    : public OpConversionPattern<vector::LoadOp> {
+  using OpConversionPattern<vector::LoadOp>::OpConversionPattern;
+
+  LinearizeVectorLoad(
+    const TypeConverter &typeConverter, MLIRContext *context,
+    unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+    PatternBenefit benefit = 1)
+    : OpConversionPattern(typeConverter, context, benefit),
+      targetVectorBitWidth(targetVectBitWidth) {}
+
+  LogicalResult
+  matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = loadOp->getLoc();
+    auto vecType = loadOp.getVectorType();
+    auto shape = vecType.getShape();
+
+    if (shape.size() != 2) {
+      return rewriter.notifyMatchFailure(loc, "Can only linearize 2D vectors.");
+    }
+    auto unrollCount = shape[0];
+    auto vecSize = shape[1];
+    auto newVecType =
+        VectorType::get({vecSize}, vecType.getElementType());
+
+    llvm::SmallVector<Value, 4> indices = adaptor.getIndices();
+    Value xBaseIndex = indices[0];
+
+    // Construct the 2D vector.
+    Value resultVec = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getZeroAttr(vecType));
+    // Emit unrolled loads for each 1D vector slice.
+    for (auto i = 0; i < unrollCount; i++) {
+      Value xIndex = xBaseIndex;
+      if (i) {
+        auto increment = rewriter.create<arith::ConstantIndexOp>(loc, i);
+        xIndex =
+            rewriter.create<arith::AddIOp>(loc, xBaseIndex, increment);
+      }
+      indices[0] = xIndex;
+      auto vec = rewriter.create<vector::LoadOp>(
+          loc, newVecType, adaptor.getBase(), indices);
+      resultVec =
+          rewriter.create<vector::InsertOp>(loc, vec, resultVec, i);
+    }
+
+    rewriter.replaceOp(loadOp, resultVec);
+    return success();
+  }
+  private:
+    unsigned targetVectorBitWidth;
+};
+
+/// This pattern converts the StoreOp to a series of StoreOp & ExtractOp
+/// that works on a linearized vector.
+/// Following,
+///   vector.store %source, %base[%indices] : vector<4x4xf32>
+/// is converted to :
+///   %slice_0 = vector.extract %source[0] : vector<4xf32>
+///   vector.store %slice_0, %base[%indices] : vector<4xf32>
+///   %slice_1 = vector.extract %source[1] : vector<4xf32>
+///   vector.store %slice_1, %base[%indices + 1] : vector<4xf32>
+///   ...
+/// This unrolls the 2D vector store into multiple 1D vector stores by extracting
+/// slices from the source vector and storing them into the destination.
+/// The pattern currently supports only 2D vectors.
+struct LinearizeVectorStore final
+    : public OpConversionPattern<vector::StoreOp> {
+  using OpConversionPattern<vector::StoreOp>::OpConversionPattern;
+
+  LinearizeVectorStore(
+    const TypeConverter &typeConverter, MLIRContext *context,
+    unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+    PatternBenefit benefit = 1)
+    : OpConversionPattern(typeConverter, context, benefit),
+      targetVectorBitWidth(targetVectBitWidth) {}
+
+  LogicalResult
+  matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = storeOp->getLoc();
+    auto vecType = storeOp.getVectorType();
+    auto shape = vecType.getShape();
+
+    if (shape.size() != 2) {
+      return rewriter.notifyMatchFailure(loc, "Can only linearize 2D vectors.");
+    }
+
+    auto unrollCount = shape[0];
+    llvm::SmallVector<Value, 4> indices = adaptor.getIndices();
+    Value xBaseIndex = indices[0];
+
+    auto vec = rewriter.create<vector::ShapeCastOp>(
+        loc, vecType, adaptor.getValueToStore());
+
+    for (auto i = 0; i < unrollCount; i++) {
+      auto vecSlice = rewriter.create<vector::ExtractOp>(loc, vec, i);
+      Value xIndex = xBaseIndex;
+      if (i) {
+        auto increment = rewriter.create<arith::ConstantIndexOp>(loc, i);
+        xIndex =
+            rewriter.create<arith::AddIOp>(loc, xBaseIndex, increment);
+      }
+      indices[0] = xIndex;
+      rewriter.create<vector::StoreOp>(loc, vecSlice, adaptor.getBase(),
+                                             indices);
+    }
+    rewriter.eraseOp(storeOp);
+    return success();
+  }
+  private:
+    unsigned targetVectorBitWidth;
+};
+
+/// This pattern converts the SplatOp to work on a linearized vector.
+/// Following,
+///   vector.splat %value : vector<4x4xf32>
+/// is converted to:
+///   %out_1d = vector.splat %value : vector<16xf32>
+///   %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32>
+/// It ensures that the operation is compatible with the target vector
+/// bit width and replaces the original operation with a new SplatOp
+/// that operates on the converted type.
+struct LinearizeVectorSplat final
+    : public OpConversionPattern<vector::SplatOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LinearizeVectorSplat(
+    const TypeConverter &typeConverter, MLIRContext *context,
+    unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+    PatternBenefit benefit = 1)
+    : OpConversionPattern(typeConverter, context, benefit),
+      targetVectorBitWidth(targetVectBitWidth) {}
+
+  LogicalResult
+  matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto dstTy = getTypeConverter()->convertType(splatOp.getType());
+    if (!dstTy)
+      return rewriter.notifyMatchFailure(splatOp, "cannot convert type.");
+    rewriter.replaceOpWithNewOp<vector::SplatOp>(
+        splatOp, adaptor.getInput(), dstTy);
+    return success();
+  }
+  private:
+    unsigned targetVectorBitWidth;
+};
+
+/// This pattern converts the CreateMaskOp to work on a
+/// linearized vector. It ensures that the operation is compatible with the
+/// target vector bit width and replaces the original operation with a new
+/// CreateMaskOp that operates on the converted type. The pattern currently
+/// supports only 2D masks with a unit outer dimension.
+/// Following,
+///   vector.create_mask %dims : vector<1x4xi1>
+/// is converted to:
+///   %out_1d = vector.create_mask %dims : vector<4xi1>
+///   %out_nd = vector.shape_cast %out_1d : vector<4xi1> to vector<1x4xi1>
+struct LinearizeVectorCreateMask final
+    : OpConversionPattern<vector::CreateMaskOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LinearizeVectorCreateMask(
+    const TypeConverter &typeConverter, MLIRContext *context,
+    unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+    PatternBenefit benefit = 1)
+    : OpConversionPattern(typeConverter, context, benefit),
+      targetVectorBitWidth(targetVectBitWidth) {}
+
+  LogicalResult
+  matchAndRewrite(vector::CreateMaskOp createMaskOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto srcTy = createMaskOp.getType();
+    auto srcShape = srcTy.getShape();
+    if (srcShape.size() != 2)
+      return rewriter.notifyMatchFailure(createMaskOp,
+                                         "only 2D mask is supported.");
+
+    if (srcShape[0] != 1)
+      return rewriter.notifyMatchFailure(
+          createMaskOp, "only unit outer dimension is supported.");
+
+    auto dstTy = getTypeConverter()->convertType(srcTy);
+    if (!dstTy)
+      return rewriter.notifyMatchFailure(createMaskOp, "cannot convert type.");
+
+    rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
+        createMaskOp, dstTy, adaptor.getOperands().back());
+    return success();
+  }
+  private:
+    unsigned targetVectorBitWidth;
+};
+
+/// This pattern converts operations implementing the RegionBranchOpInterface
+/// to ensure compatibility with linearized vector types. It updates the
+/// operands, result types, and region types (block arguments and yields) to
+/// match the converted types. Additionally, it processes yields within each
+/// region to ensure that the types of yielded values are compatible with the
+/// target vector bit width. If the result types of the operation are updated,
+/// shape cast operations are inserted to maintain compatibility with the
+/// original types. This pattern ensures that operations with regions are
+/// properly linearized and remain valid after type conversion.
+struct LinearizeRegionBranchOp final
+    : public OpInterfaceConversionPattern<RegionBranchOpInterface> {
+  using OpInterfaceConversionPattern<
+      RegionBranchOpInterface>::OpInterfaceConversionPattern;
+
+  LinearizeRegionBranchOp(
+    const TypeConverter &typeConverter, MLIRContext *context,
+    unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+    PatternBenefit benefit = 1)
+    : OpInterfaceConversionPattern(typeConverter, context, benefit),
+      targetVectorBitWidth(targetVectBitWidth) {}
+
+  LogicalResult
+  matchAndRewrite(RegionBranchOpInterface op,
+                  ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto converter = getTypeConverter();
+
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.startOpModification(op);
+
+    llvm::SmallVector<Type> convertedTypes;
+    for (Type ty : op->getResultTypes()) {
+      convertedTypes.push_back(converter->convertType(ty));
+    }
+
+    if (convertedTypes == op->getResultTypes() &&
+        op->getOperands() == operands) {
+      return failure();
+    }
+
+    op->setOperands(operands);
+
+    // Convert region types (block arguments and yields)
+    for (Region &region : op->getRegions()) {
+      if (failed(rewriter.convertRegionTypes(&region, *converter))) {
+        return failure();
+      }
+
+      // Process yields within each region
+      for (Block &block : region) {
+        if (auto *terminator = block.getTerminator()) {
+          for (OpOperand &yieldOperand : terminator->getOpOperands()) {
+            Value value = yieldOperand.get();
+            Type type = value.getType();
+            if (!converter->isLegal(type)) {
+              Type newTy = converter->convertType(type);
+              rewriter.setInsertionPoint(terminator);
+              Value newValue =
+                  rewriter.create<vector::ShapeCastOp>(loc, newTy, value);
+              yieldOperand.set(newValue);
+            }
+          }
+        }
+      }
+    }
+
+    // Update result types
+    rewriter.setInsertionPointAfter(op);
+    llvm::SmallVector<Value> newResults;
+    for (Value result : op->getResults()) {
+      Type oldTy = result.getType();
+      if (!converter->isLegal(oldTy)) {
+        Type newTy = converter->convertType(oldTy);
+        result.setType(newTy);
+        Operation *castOp =
+            rewriter.create<vector::ShapeCastOp>(loc, oldTy, result);
+        result.replaceAllUsesExcept(castOp->getResult(0), castOp);
+        newResults.push_back(castOp->getResult(0));
+      } else {
+        newResults.push_back(result);
+      }
+    }
+
+    rewriter.finalizeOpModification(op);
+    return success();
+  }
+  private:
+    unsigned targetVectorBitWidth;
+};
+
 } // namespace
 
 void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     ConversionTarget &target, unsigned targetBitWidth) {
 
+  typeConverter.addConversion([](Type type) -> Type { return type; });
   typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
     if (!isLinearizableVector(type))
       return type;
@@ -555,9 +936,12 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
   };
   typeConverter.addSourceMaterialization(materializeCast);
   typeConverter.addTargetMaterialization(materializeCast);
+  target.addLegalOp<mlir::vector::ShapeCastOp>();
   target.markUnknownOpDynamicallyLegal(
       [=](Operation *op) -> std::optional<bool> {
-        if ((isa<vector::BitCastOp>(op) ||
+        if ((isa<vector::BitCastOp, vector::LoadOp,
+                 vector::StoreOp, vector::CreateMaskOp,
+                 RegionBranchOpInterface, vector::SplatOp>(op) ||
              op->hasTrait<OpTrait::ConstantLike>() ||
              op->hasTrait<OpTrait::Vectorizable>())) {
           return (isLessThanTargetBitWidth(op, targetBitWidth)
@@ -568,7 +952,10 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
       });
 
   patterns.add<LinearizeConstantLike, LinearizeVectorizable,
-               LinearizeVectorBitCast>(typeConverter, patterns.getContext(),
+               LinearizeVectorBitCast, LinearizeVectorLoad,
+               LinearizeVectorStore, LinearizeVectorSplat,
+               LinearizeVectorCreateMask, LinearizeRegionBranchOp
+               >(typeConverter, patterns.getContext(),
                                        targetBitWidth);
 }
 
@@ -583,7 +970,21 @@ void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
                               .getRank() == 1)
                    : true;
       });
+
+  target.addDynamicallyLegalOp<vector::InsertStridedSliceOp>(
+    [=](vector::InsertStridedSliceOp op) -> bool {
+      if(isLessThanTargetBitWidth(op, targetBitWidth)) {
+        auto srcTy = op.getSourceVectorType();
+        auto dstTy = op.getDestVectorType();
+        if (!op.hasNonUnitStrides() && srcTy.getRank() == 2 &&
+            srcTy.hasStaticShape() && dstTy.hasStaticShape())
+          return false;
+      }
+      return true;
+    });
+
   patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
-               LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
+               LinearizeVectorInsert, LinearizeVectorExtractStridedSlice,
+               LinearizeVectorInsertStridedSlice>(
       typeConverter, patterns.getContext(), targetBitWidth);
 }
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 9052c6440e6ac..e47e7c4a84d68 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -399,3 +399,338 @@ func.func @test_vector_bitcast(%arg0: vector<[4]x2xf32>) -> vector<[4]x4xf16> {
   %1 = vector.bitcast %arg0 : vector<[4]x2xf32> to vector<[4]x4xf16>
   return %1 : vector<[4]x4xf16>
 }
+
+// -----
+// ALL-LABEL: test_vector_load
+// ALL-SAME: (%[[ARG_0:.*]]: memref<4x4xf16>)
+func.func @test_vector_load(%arg0: memref<4x4xf16>) -> vector<4x4xf16> {
+  // DEFAULT: %[[C1:.*]] = arith.constant 1 : index
+  // BW-128: %[[C1:.*]] = arith.constant 1 : index
+  // DEFAULT: %[[C2:.*]] = arith.constant 2 : index
+  // BW-128: %[[C2:.*]] = arith.constant 2 : index
+...
[truncated]
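
As a worked example of the offset arithmetic in LinearizeVectorInsertStridedSlice above (illustrative numbers, not from the patch): the destination strides are the row-major suffix products of the destination shape, so for vector<4x4xf32> they are [4, 1], and

\[ \text{linearizedOffset} = \sum_i \text{offsets}[i] \cdot \text{dstStrides}[i]. \]

With offsets = [2, 1] this gives 2 * 4 + 1 * 1 = 9, so row i of a 2x4 source lands at offset 9 + 4 * i in the flattened vector<16xf32>.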

@llvmbot (Member) commented Apr 17, 2025

@llvm/pr-subscribers-mlir



github-actions bot commented Apr 17, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@nbpatel (Contributor, Author) commented Apr 17, 2025

@Hardcode84 @charithaintc @chencha3 Please take a look as well

@nbpatel nbpatel changed the title [mlr][vector] Add more patterns to Vector Linearize transformation [mlir][vector] Add more patterns to Vector Linearize transformation Apr 17, 2025
@newling (Contributor) left a comment:

Just a few drive-by comments. I'm no expert on this, so please ignore my suggestions where they're not appropriate.

@@ -27,6 +27,10 @@
using namespace mlir;

static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
// For BW-0, all operations are legal
if (targetBitWidth == 0)
Contributor:

targetBitWidth = std::numeric_limits<unsigned>::max() is used in places; please consolidate.

Contributor:

Apologies, I wasn't clear in this comment. I meant to consolidate the use of targetBitWidth = 0 and targetBitWidth = 'max'. Is this a workaround for adding 1 to std::numeric_limits<unsigned>::max()?

I would like to commit #136581, which would mean this logic doesn't live here anymore and this comment wouldn't be relevant; could you please take a look at that?

Contributor Author:

I briefly looked at it. I'm ok with that change, but can we commit this first if possible?

Contributor:

I have reached a point where I need to land this to make progress with other development; I hope that's ok, @nbpatel.

It should be easy to absorb the changes (just remove code related to bitwidth)

Contributor Author:

yes go ahead and merge it

/// %2 = vector.extract_strided_slice %s {offsets=[4], sizes=[4], strides=[1]} : vector<4xf32> from vector<8xf32>
/// %3 = vector.insert_strided_slice %2, %1 {offsets=[4], strides=[1]} : vector<4xf32> into vector<16xf32>
struct LinearizeVectorInsertStridedSlice final
Contributor:

Perhaps it's possible to reuse populateVectorInsertExtractStridedSliceTransforms to convert to shuffle?

Contributor Author:

I think since this is not a pass and just a bunch of patterns, users can decide how they want to lower from this point.

Contributor:

I would like it if this were a collection of patterns that the user could opt in or out of one-by-one, but currently there are only 2 APIs exposed to the user (populateVectorLinearizeShuffleLikeOpsPatterns and populateVectorLinearizeTypeConversionsAndLegality).

Contributor Author:

I guess the users can run populateVectorLinearizeTypeConversionsAndLegality and populateVectorInsertExtractStridedSliceTransforms to convert to shuffle.
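
A sketch of that combination (illustrative only; it assumes the existing vector::populateVectorInsertExtractStridedSliceTransforms entry point, and the wrapper function is hypothetical):

#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Linearize vector types, then lower insert/extract_strided_slice through
// insert/extract (and ultimately shuffle) rather than keeping them around.
static LogicalResult linearizeAndLowerStridedSlices(Operation *root,
                                                    unsigned targetBitWidth) {
  MLIRContext *ctx = root->getContext();
  TypeConverter typeConverter;
  RewritePatternSet patterns(ctx);
  ConversionTarget target(*ctx);

  vector::populateVectorLinearizeTypeConversionsAndLegality(
      typeConverter, patterns, target, targetBitWidth);
  vector::populateVectorInsertExtractStridedSliceTransforms(patterns);

  return applyPartialConversion(root, target, std::move(patterns));
}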

// Skip if result is not a vector type
if (!isa<VectorType>(extractOp.getType()))
return rewriter.notifyMatchFailure(extractOp,
"scalar extract is not supported.");
Contributor:

Suggested change:
- "scalar extract is not supported.");
+ "scalar extract is not supported, because ...");

might be helpful!

@@ -531,12 +618,239 @@ struct LinearizeVectorBitCast final
unsigned targetVectorBitWidth;
};

/// This pattern converts the LoadOp to a series of LoadOp & InsertOp
Contributor:

formatting.

/// Following,
/// vector.load %base[%indices] : vector<4x4xf32>
/// is converted to :
/// %result = arith.constant dense<0.0> : vector<4x4xf32>
Contributor:

I wonder if it would be beneficial to flatten/linearize out all contiguous dimensions first, i.e. if the load is actually unstrided, like

%result = vector.load %base[%i, %j] : memref<100x100xf32>, vector<8x100xf32>

it could be flattened to

%result = vector.load %flat_base[%i] : memref<10000xf32>, vector<800xf32>

the IR generated wouldn't be unrolled.

It seems to me like this is more unrolling than linearizing?

void populateVectorUnrollPatterns(RewritePatternSet &patterns, const UnrollVectorOptions &options, PatternBenefit benefit = 1);
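
A sketch of that route (illustrative; it assumes load/store unroll patterns are registered behind populateVectorUnrollPatterns, which is exactly the change being suggested here):

#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Unroll 2-D vector.load/vector.store into rank-1 rows via the generic
// unrolling hook instead of a bespoke linearization pattern.
static LogicalResult unrollLoadsAndStores(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  vector::UnrollVectorOptions options;
  // Example native shape: split vector<4x4xf16> ops into four 1x4 rows.
  options.setNativeShape({1, 4});
  options.setFilterConstraint([](Operation *op) {
    return success(isa<vector::LoadOp, vector::StoreOp>(op));
  });
  vector::populateVectorUnrollPatterns(patterns, options);
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}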

Contributor Author:

This is a more general solution for unstrided/strided memrefs, and we can always fuse the loads later on as an optimization.


/// vector.insert_strided_slice %s, %d {offsets=[0, 0]}: vector<2x4xf32> into
/// vector<4x4xf32>
/// is converted to :
/// %0 = vector.extract_strided_slice %s {offsets=[0], sizes=[4], strides=[1]}
Contributor:

Please check the formatting. Maybe clang-format off and clang-format on will be helpful?

@newling (Contributor) commented Apr 23, 2025

Hi @nbpatel, I think these patterns will all be useful for users of the vector dialect, but I have a few requests that I would like to make first:

  1. While unrolling is a kind of linearization in spirit, there is what seems to me a better place to add the patterns for load/store unrolling: mlir::vector::populateVectorUnrollPatterns. I think if we keep the linearization patterns as strictly "shape_casts" -> "single op in rank 1" -> "shape_casts" it will be easier to understand the code base.

  2. Regarding the unrolling pattern here for insert_strided_slice, it's not clear why this isn't rather a conversion to shuffle (like extract_strided_slice is), in which case how is it different to the existing insert/extract strided slice transforms? And if unrolling is required, I think it should migrate to the unrolling patterns.

  3. Please put the splat and create_mask patterns into separate PRs.

I know they're large requests, my apologies for not suggesting them earlier in the process. I think your contributions will be significantly more useful if done this way.

@banach-space (Contributor) left a comment:

Thanks!

This is adding 4 unrelated patterns - could you split this into independent PRs?

Comment on lines +650 to +652
auto loc = loadOp->getLoc();
auto vecType = loadOp.getVectorType();
auto shape = vecType.getShape();
Contributor:

Could you spell out the types here and in other places? For reference, see LLVM's guidelines for using auto.

// -----
// ALL-LABEL: test_vector_load
// ALL-SAME: (%[[ARG_0:.*]]: memref<4x4xf16>)
func.func @test_vector_load(%arg0: memref<4x4xf16>) -> vector<4x4xf16> {
Contributor:

Please avoid using test in test function names - that's unnecessary noise. I appreciate that you are trying to follow the existing convention in this file, but IMO we should focus on encoding unique information; see our testing guidelines.

So, what makes this test unique?

Comment on lines +407 to +449
// DEFAULT: %[[C1:.*]] = arith.constant 1 : index
// BW-128: %[[C1:.*]] = arith.constant 1 : index
// DEFAULT: %[[C2:.*]] = arith.constant 2 : index
// BW-128: %[[C2:.*]] = arith.constant 2 : index
// DEFAULT: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<16xf16>
// BW-128: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<16xf16>
// DEFAULT: %[[LOAD0:.*]] = vector.load %[[ARG_0]][%[[C1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
// BW-128: %[[LOAD0:.*]] = vector.load %[[ARG_0]][%[[C1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
// DEFAULT: %[[SHUFFLE0:.*]] = vector.shuffle %[[CST]], %[[LOAD0]] [16, 17, 18, 19, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<16xf16>, vector<4xf16>
// BW-128: %[[SHUFFLE0:.*]] = vector.shuffle %[[CST]], %[[LOAD0]] [16, 17, 18, 19, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<16xf16>, vector<4xf16>
// DEFAULT: %[[C1_0:.*]] = arith.constant 1 : index
// BW-128: %[[C1_0:.*]] = arith.constant 1 : index
// DEFAULT: %[[ADD0:.*]] = arith.addi %[[C1]], %[[C1_0]] : index
// BW-128: %[[ADD0:.*]] = arith.addi %[[C1]], %[[C1_0]] : index
// DEFAULT: %[[LOAD1:.*]] = vector.load %[[ARG_0]][%[[ADD0]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
// BW-128: %[[LOAD1:.*]] = vector.load %[[ARG_0]][%[[ADD0]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
// DEFAULT: %[[SHUFFLE1:.*]] = vector.shuffle %[[SHUFFLE0]], %[[LOAD1]] [0, 1, 2, 3, 16, 17, 18, 19, 8, 9, 10, 11, 12, 13, 14, 15] : vector<16xf16>, vector<4xf16>
// BW-128: %[[SHUFFLE1:.*]] = vector.shuffle %[[SHUFFLE0]], %[[LOAD1]] [0, 1, 2, 3, 16, 17, 18, 19, 8, 9, 10, 11, 12, 13, 14, 15] : vector<16xf16>, vector<4xf16>
// DEFAULT: %[[C2_1:.*]] = arith.constant 2 : index
// BW-128: %[[C2_1:.*]] = arith.constant 2 : index
// DEFAULT: %[[ADD1:.*]] = arith.addi %[[C1]], %[[C2_1]] : index
// BW-128: %[[ADD1:.*]] = arith.addi %[[C1]], %[[C2_1]] : index
// DEFAULT: %[[LOAD2:.*]] = vector.load %[[ARG_0]][%[[ADD1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
// BW-128: %[[LOAD2:.*]] = vector.load %[[ARG_0]][%[[ADD1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
// DEFAULT: %[[SHUFFLE2:.*]] = vector.shuffle %[[SHUFFLE1]], %[[LOAD2]] [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 12, 13, 14, 15] : vector<16xf16>, vector<4xf16>
// BW-128: %[[SHUFFLE2:.*]] = vector.shuffle %[[SHUFFLE1]], %[[LOAD2]] [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 12, 13, 14, 15] : vector<16xf16>, vector<4xf16>
// DEFAULT: %[[C3:.*]] = arith.constant 3 : index
// BW-128: %[[C3:.*]] = arith.constant 3 : index
// DEFAULT: %[[ADD2:.*]] = arith.addi %[[C1]], %[[C3]] : index
// BW-128: %[[ADD2:.*]] = arith.addi %[[C1]], %[[C3]] : index
// DEFAULT: %[[LOAD3:.*]] = vector.load %[[ARG_0]][%[[ADD2]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
// BW-128: %[[LOAD3:.*]] = vector.load %[[ARG_0]][%[[ADD2]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
// DEFAULT: %[[SHUFFLE3:.*]] = vector.shuffle %[[SHUFFLE2]], %[[LOAD3]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 16, 17, 18, 19] : vector<16xf16>, vector<4xf16>
// BW-128: %[[SHUFFLE3:.*]] = vector.shuffle %[[SHUFFLE2]], %[[LOAD3]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 16, 17, 18, 19] : vector<16xf16>, vector<4xf16>
// DEFAULT: %[[CAST:.*]] = vector.shape_cast %[[SHUFFLE3]] : vector<16xf16> to vector<4x4xf16>
// BW-128: %[[CAST:.*]] = vector.shape_cast %[[SHUFFLE3]] : vector<16xf16> to vector<4x4xf16>
// DEFAULT: return %[[CAST]] : vector<4x4xf16>
// BW-128: return %[[CAST]] : vector<4x4xf16>

// BW-0: %[[C1:.*]] = arith.constant 1 : index
// BW-0: %[[C2:.*]] = arith.constant 2 : index
// BW-0: %[[LOAD:.*]] = vector.load %[[ARG_0]][%[[C1]], %[[C2]]] : memref<4x4xf16>, vector<4x4xf16>
// BW-0: return %[[LOAD]] : vector<4x4xf16>
Contributor:

It is very hard to follow this. Could you follow the pre-existing convention and split the DEFAULT, BW-128, and BW-0 blocks (as opposed to interleaving them)? Similar comment for the other tests.

Comment on lines +293 to +294
/// %1 = vector.insert_strided_slice %0, %d {offsets=[0], strides=[1]}
/// : vector<4xf32> into vector<16xf32>
Contributor:

Suggested change:
- /// %1 = vector.insert_strided_slice %0, %d {offsets=[0], strides=[1]}
- /// : vector<4xf32> into vector<16xf32>
+ /// %1 = vector.insert_strided_slice %0, %d {offsets=[0], strides=[1]}
+ ///   : vector<4xf32> into vector<16xf32>

Comment on lines +288 to +289
/// vector.insert_strided_slice %s, %d {offsets=[0, 0]}: vector<2x4xf32> into
/// vector<4x4xf32>
Contributor:

Suggested change:
- /// vector.insert_strided_slice %s, %d {offsets=[0, 0]}: vector<2x4xf32> into
- /// vector<4x4xf32>
+ /// vector.insert_strided_slice %s, %d {offsets=[0, 0]}
+ ///   : vector<2x4xf32> into vector<4x4xf32>

if (srcTy.getRank() != 2)
return rewriter.notifyMatchFailure(
insertOp,
"InsertStridedSliceOp linearization only supports 2D source.");
Contributor:

any reason for supporting only 2D?

@nbpatel (Contributor, Author) commented Apr 23, 2025

> Thanks!
>
> This is adding 4 unrelated patterns - could you split this into independent PRs?

So load/store as one PR, and splat, create_mask, and insert_strided_slice each as independent PRs?

@nbpatel nbpatel closed this Apr 27, 2025
newling pushed a commit that referenced this pull request May 1, 2025
This PR is a breakdown [2 / 4] of the PR #136193 
The PR adds linearization patterns for vector.splat.
IanWood1 pushed a commit to IanWood1/llvm-project that referenced this pull request May 6, 2025
This PR is a breakdown [2 / 4] of the PR llvm#136193 
The PR adds linearization patterns for vector.splat.
GeorgeARM pushed a commit to GeorgeARM/llvm-project that referenced this pull request May 7, 2025
This PR is a breakdown [2 / 4] of the PR llvm#136193 
The PR adds linearization patterns for vector.splat.
Ankur-0429 pushed a commit to Ankur-0429/llvm-project that referenced this pull request May 9, 2025
This PR is a breakdown [2 / 4] of the PR llvm#136193 
The PR adds linearization patterns for vector.splat.
newling pushed a commit that referenced this pull request May 14, 2025
…8214)

This PR is a breakdown [3 / 4] of the PR #136193 
The PR adds linearization patterns for vector.create_mask
TIFitis pushed a commit to TIFitis/llvm-project that referenced this pull request May 19, 2025
…m#138214)

This PR is a breakdown [3 / 4] of the PR llvm#136193 
The PR adds linearization patterns for vector.create_mask
@nbpatel nbpatel deleted the vector_linearize branch June 10, 2025 22:24