Skip to content

[mlir][Vector] Remove uses of vector.extractelement/vector.insertelement #113827

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from

Conversation

Groverkss
Copy link
Member

This patch removes usages of vector.extractelement/vector.insertelement. These operations can be fully represented by vector.extract/vector.insert. See https://discourse.llvm.org/t/rfc-psa-remove-vector-extractelement-and-vector-insertelement-ops-in-favor-of-vector-extract-and-vector-insert-ops/71116 for more information.

@llvmbot
Copy link
Member

llvmbot commented Oct 27, 2024

@llvm/pr-subscribers-mlir-vector
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-backend-amdgpu

Author: Kunwar Grover (Groverkss)

Changes

This patch removes usages of vector.extractelement/vector.insertelement. These operations can be fully represented by vector.extract/vector.insert. See https://discourse.llvm.org/t/rfc-psa-remove-vector-extractelement-and-vector-insertelement-ops-in-favor-of-vector-extract-and-vector-insert-ops/71116 for more information.


Patch is 73.49 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/113827.diff

20 Files Affected:

  • (modified) mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp (+3-6)
  • (modified) mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp (+9-10)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+5-5)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+15-8)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp (+2-1)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp (+2-3)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp (+2-18)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp (+5-7)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp (+4-14)
  • (modified) mlir/test/Conversion/ArithToAMDGPU/16-bit-floats.mlir (+12-12)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+1-2)
  • (modified) mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir (+7-7)
  • (modified) mlir/test/Dialect/Linalg/vectorization-scalable.mlir (+1-1)
  • (modified) mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir (+4-6)
  • (modified) mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir (+4-8)
  • (modified) mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir (+28-27)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+15-3)
  • (modified) mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir (+31-58)
  • (modified) mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir (+2-4)
  • (modified) mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir (+2-2)
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 6b27ec9947cb0b..6b9cbaf57676c2 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -313,8 +313,7 @@ void TruncfToFloat16RewritePattern::rewrite(arith::TruncFOp op,
     auto sourceB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
     Value asF16s =
         rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, in, sourceB);
-    Value result = rewriter.create<vector::ExtractElementOp>(
-        loc, asF16s, rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0));
+    Value result = rewriter.create<vector::ExtractOp>(loc, asF16s, 0);
     return rewriter.replaceOp(op, result);
   }
   VectorType outType = cast<VectorType>(op.getOut().getType());
@@ -334,13 +333,11 @@ void TruncfToFloat16RewritePattern::rewrite(arith::TruncFOp op,
   for (int64_t i = 0; i < numElements; i += 2) {
     int64_t elemsThisOp = std::min(numElements, i + 2) - i;
     Value thisResult = nullptr;
-    Value elemA = rewriter.create<vector::ExtractElementOp>(
-        loc, in, rewriter.create<arith::ConstantIndexOp>(loc, i));
+    Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i);
     Value elemB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
 
     if (elemsThisOp == 2) {
-      elemB = rewriter.create<vector::ExtractElementOp>(
-          loc, in, rewriter.createOrFold<arith::ConstantIndexOp>(loc, i + 1));
+      elemB = rewriter.create<vector::ExtractOp>(loc, in, i + 1);
     }
 
     thisResult =
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 3a4dc806efe976..ddbc4d2c4a4f3d 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -157,7 +157,7 @@ static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) {
     return Value();
 
   Location loc = xferOp.getLoc();
-  return b.create<vector::ExtractElementOp>(loc, xferOp.getMask(), iv);
+  return b.create<vector::ExtractOp>(loc, xferOp.getMask(), iv);
 }
 
 /// Helper function TransferOpConversion and TransferOp1dConversion.
@@ -686,7 +686,7 @@ struct PrepareTransferWriteConversion
 /// %lastIndex = arith.subi %length, %c1 : index
 /// vector.print punctuation <open>
 /// scf.for %i = %c0 to %length step %c1 {
-///   %el = vector.extractelement %v[%i : index] : vector<[4]xi32>
+///   %el = vector.extract %v[%i : index] : vector<[4]xi32>
 ///   vector.print %el : i32 punctuation <no_punctuation>
 ///   %notLastIndex = arith.cmpi ult, %i, %lastIndex : index
 ///   scf.if %notLastIndex {
@@ -756,7 +756,8 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
     if (vectorType.getRank() != 1) {
       // Flatten n-D vectors to 1D. This is done to allow indexing with a
       // non-constant value (which can currently only be done via
-      // vector.extractelement for 1D vectors).
+      // vector.extract for 1D vectors).
+      // TODO: vector.extract supports N-D non-constant indices now.
       auto flatLength = std::accumulate(shape.begin(), shape.end(), 1,
                                         std::multiplies<int64_t>());
       auto flatVectorType =
@@ -819,8 +820,7 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
     }
 
     // Print the scalar elements in the inner most loop.
-    auto element =
-        rewriter.create<vector::ExtractElementOp>(loc, value, flatIndex);
+    auto element = rewriter.create<vector::ExtractOp>(loc, value, flatIndex);
     rewriter.create<vector::PrintOp>(loc, element,
                                      vector::PrintPunctuation::NoPunctuation);
 
@@ -1563,7 +1563,7 @@ struct Strategy1d<TransferReadOp> {
         [&](OpBuilder &b, Location loc) {
           Value val =
               b.create<memref::LoadOp>(loc, xferOp.getSource(), indices);
-          return b.create<vector::InsertElementOp>(loc, val, vec, iv);
+          return b.create<vector::InsertOp>(loc, val, vec, iv);
         },
         /*outOfBoundsCase=*/
         [&](OpBuilder & /*b*/, Location loc) { return vec; });
@@ -1591,8 +1591,7 @@ struct Strategy1d<TransferWriteOp> {
     generateInBoundsCheck(
         b, xferOp, iv, dim,
         /*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
-          auto val =
-              b.create<vector::ExtractElementOp>(loc, xferOp.getVector(), iv);
+          auto val = b.create<vector::ExtractOp>(loc, xferOp.getVector(), iv);
           b.create<memref::StoreOp>(loc, val, xferOp.getSource(), indices);
         });
     b.create<scf::YieldOp>(loc);
@@ -1614,7 +1613,7 @@ struct Strategy1d<TransferWriteOp> {
 /// This pattern generates IR as follows:
 ///
 /// 1. Generate a for loop iterating over each vector element.
-/// 2. Inside the loop, generate a InsertElementOp or ExtractElementOp,
+/// 2. Inside the loop, generate a InsertOp or ExtractOp,
 ///    depending on OpTy.
 ///
 /// TODO: In some cases (no masking, etc.), LLVM::MatrixColumnMajorLoadOp
@@ -1630,7 +1629,7 @@ struct Strategy1d<TransferWriteOp> {
 /// Is rewritten to approximately the following pseudo-IR:
 /// ```
 /// for i = 0 to 9 {
-///   %t = vector.extractelement %vec[i] : vector<9xf32>
+///   %t = vector.extract %vec[i] : vector<9xf32>
 ///   memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32>
 /// }
 /// ```
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 0a2457176a1d47..e38bbad1637d45 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1134,8 +1134,6 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
   //   * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
   //    (0th) element and use that.
   SmallVector<Value> transferReadIdxs;
-  auto zero = rewriter.create<arith::ConstantOp>(
-      loc, rewriter.getI32Type(), rewriter.getZeroAttr(rewriter.getI32Type()));
   for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
     Value idx = bvm.lookup(extractOp.getIndices()[i]);
     if (idx.getType().isIndex()) {
@@ -1149,7 +1147,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
                         resultType.getScalableDims().back()),
         idx);
     transferReadIdxs.push_back(
-        rewriter.create<vector::ExtractElementOp>(loc, indexAs1dVector, zero));
+        rewriter.create<vector::ExtractOp>(loc, indexAs1dVector, 0));
   }
 
   // `tensor.extract_element` is always in-bounds, hence the following holds.
@@ -1415,7 +1413,8 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
     // 3.c. Not all ops support 0-d vectors, extract the scalar for now.
     // TODO: remove this.
     if (readType.getRank() == 0)
-      readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
+      readValue = rewriter.create<vector::ExtractOp>(loc, readValue,
+                                                     SmallVector<int64_t>{});
 
     LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue
                                  << "\n");
@@ -2268,7 +2267,8 @@ LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
       loc, readType, copyOp.getSource(), indices,
       rewriter.getMultiDimIdentityMap(srcType.getRank()));
   if (cast<VectorType>(readValue.getType()).getRank() == 0) {
-    readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
+    readValue = rewriter.create<vector::ExtractOp>(loc, readValue,
+                                                   SmallVector<int64_t>{});
     readValue = rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
   }
   Operation *writeValue = rewriter.create<vector::TransferWriteOp>(
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d71a236f62f454..af5b3637bf5b10 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -697,8 +697,9 @@ struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
     Value result;
     if (vectorType.getRank() == 0) {
       if (mask)
-        mask = rewriter.create<ExtractElementOp>(loc, mask);
-      result = rewriter.create<ExtractElementOp>(loc, reductionOp.getVector());
+        mask = rewriter.create<ExtractOp>(loc, mask, SmallVector<int64_t>{});
+      result = rewriter.create<ExtractOp>(loc, reductionOp.getVector(),
+                                          SmallVector<int64_t>{});
     } else {
       if (mask)
         mask = rewriter.create<ExtractOp>(loc, mask, 0);
@@ -1983,12 +1984,18 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
     if (extractResultRank < broadcastSrcRank)
       return failure();
 
-    // Special case if broadcast src is a 0D vector.
+    // If extractResultRank is 0, broadcastSrcRank has to be zero, since
+    // broadcastSrcRank >= extractResultRank for this pattern. If so, the input
+    // to the broadcast will be a vector<f32> or f32, but the result will be a
+    // f32, because of vector.extract 0-d semantics. Therefore, we instead
+    // just replace the broadcast with a vector.extract.
     if (extractResultRank == 0) {
       assert(broadcastSrcRank == 0 && llvm::isa<VectorType>(source.getType()));
-      rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(extractOp, source);
+      rewriter.replaceOpWithNewOp<vector::ExtractOp>(extractOp, source,
+                                                     SmallVector<int64_t>{});
       return success();
     }
+
     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
         extractOp, extractOp.getType(), source);
     return success();
@@ -2951,11 +2958,11 @@ void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
               InsertOpConstantFolder>(context);
 }
 
-// Eliminates insert operations that produce values identical to their source
-// value. This happens when the source and destination vectors have identical
-// sizes.
 OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
-  if (getNumIndices() == 0)
+  // Fold "vector.insert %v, %dest [] : vector<2x2xf32> from vector<2x2xf32>" to
+  // %v. Note: Do not fold "vector.insert %v, %dest [] : f32 into vector<f32>"
+  // (type mismatch).
+  if (getNumIndices() == 0 && getSourceType() == getResult().getType())
     return getSource();
   return {};
 }
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
index 6c36bbaee85237..6d82d753eeed80 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
@@ -65,7 +65,8 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
     if (srcRank <= 1 && dstRank == 1) {
       Value ext;
       if (srcRank == 0)
-        ext = rewriter.create<vector::ExtractElementOp>(loc, op.getSource());
+        ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(),
+                                                 SmallVector<int64_t>{});
       else
         ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);
       rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 716da55ba09aec..72bf329daaa76e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -391,9 +391,8 @@ struct TwoDimMultiReductionToReduction
         reductionOp = mlir::vector::maskOperation(rewriter, reductionOp, mask);
       }
 
-      result = rewriter.create<vector::InsertElementOp>(
-          loc, reductionOp->getResult(0), result,
-          rewriter.create<arith::ConstantIndexOp>(loc, i));
+      result = rewriter.create<vector::InsertOp>(loc, reductionOp->getResult(0),
+                                                 result, i);
     }
 
     rewriter.replaceOp(rootOp, result);
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
index 95ebd4e9fe3d99..343178c8156d25 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
@@ -177,24 +177,8 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
       }
 
       Value extract;
-      if (srcRank == 0) {
-        // 0-D vector special case
-        assert(srcIdx.empty() && "Unexpected indices for 0-D vector");
-        extract = rewriter.create<vector::ExtractElementOp>(
-            loc, op.getSourceVectorType().getElementType(), op.getSource());
-      } else {
-        extract =
-            rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
-      }
-
-      if (resRank == 0) {
-        // 0-D vector special case
-        assert(resIdx.empty() && "Unexpected indices for 0-D vector");
-        result = rewriter.create<vector::InsertElementOp>(loc, extract, result);
-      } else {
-        result =
-            rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
-      }
+      extract = rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
+      result = rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
     }
     rewriter.replaceOp(op, result);
     return success();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 2289fd1ff1364e..4ea6bcf3181dae 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1238,7 +1238,7 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
     if (extractOp.getNumIndices() == 0)
       return failure();
 
-    // Rewrite vector.extract with 1d source to vector.extractelement.
+    // Rewrite vector.extract with 1d source to vector.extract.
     if (extractSrcType.getRank() == 1) {
       if (extractOp.hasDynamicPosition())
         // TODO: Dinamic position not supported yet.
@@ -1247,9 +1247,8 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
       assert(extractOp.getNumIndices() == 1 && "expected 1 index");
       int64_t pos = extractOp.getStaticPosition()[0];
       rewriter.setInsertionPoint(extractOp);
-      rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(
-          extractOp, extractOp.getVector(),
-          rewriter.create<arith::ConstantIndexOp>(loc, pos));
+      rewriter.replaceOpWithNewOp<vector::ExtractOp>(
+          extractOp, extractOp.getVector(), pos);
       return success();
     }
 
@@ -1519,9 +1518,8 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
       assert(insertOp.getNumIndices() == 1 && "expected 1 index");
       int64_t pos = insertOp.getStaticPosition()[0];
       rewriter.setInsertionPoint(insertOp);
-      rewriter.replaceOpWithNewOp<vector::InsertElementOp>(
-          insertOp, insertOp.getSource(), insertOp.getDest(),
-          rewriter.create<arith::ConstantIndexOp>(loc, pos));
+      rewriter.replaceOpWithNewOp<vector::InsertOp>(
+          insertOp, insertOp.getSource(), insertOp.getDest(), pos);
       return success();
     }
 
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
index ec2ef3fc7501c2..a5d5dc00b33cd3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -21,23 +21,13 @@ using namespace mlir::vector;
 // Helper that picks the proper sequence for inserting.
 static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
                        Value into, int64_t offset) {
-  auto vectorType = cast<VectorType>(into.getType());
-  if (vectorType.getRank() > 1)
-    return rewriter.create<InsertOp>(loc, from, into, offset);
-  return rewriter.create<vector::InsertElementOp>(
-      loc, vectorType, from, into,
-      rewriter.create<arith::ConstantIndexOp>(loc, offset));
+  return rewriter.create<InsertOp>(loc, from, into, offset);
 }
 
 // Helper that picks the proper sequence for extracting.
 static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
                         int64_t offset) {
-  auto vectorType = cast<VectorType>(vector.getType());
-  if (vectorType.getRank() > 1)
-    return rewriter.create<ExtractOp>(loc, vector, offset);
-  return rewriter.create<vector::ExtractElementOp>(
-      loc, vectorType.getElementType(), vector,
-      rewriter.create<arith::ConstantIndexOp>(loc, offset));
+  return rewriter.create<ExtractOp>(loc, vector, offset);
 }
 
 /// RewritePattern for InsertStridedSliceOp where source and destination vectors
@@ -277,8 +267,8 @@ class Convert1DExtractStridedSliceIntoExtractInsertChain final
 };
 
 /// RewritePattern for ExtractStridedSliceOp where the source vector is n-D.
-/// For such cases, we can rewrite it to ExtractOp/ExtractElementOp + lower
-/// rank ExtractStridedSliceOp + InsertOp/InsertElementOp for the n-D case.
+/// For such cases, we can rewrite it to ExtractOp + lower rank
+/// ExtractStridedSliceOp + InsertOp for the n-D case.
 class DecomposeNDExtractStridedSlice
     : public OpRewritePattern<ExtractStridedSliceOp> {
 public:
diff --git a/mlir/test/Conversion/ArithToAMDGPU/16-bit-floats.mlir b/mlir/test/Conversion/ArithToAMDGPU/16-bit-floats.mlir
index 121cae26748a82..8991506dee1dfb 100644
--- a/mlir/test/Conversion/ArithToAMDGPU/16-bit-floats.mlir
+++ b/mlir/test/Conversion/ArithToAMDGPU/16-bit-floats.mlir
@@ -5,7 +5,7 @@
 func.func @scalar_trunc(%v: f32) -> f16{
   // CHECK: %[[poison:.*]] = llvm.mlir.poison : f32
   // CHECK: %[[trunc:.*]] = rocdl.cvt.pkrtz %[[value]], %[[poison]] : vector<2xf16>
-  // CHECK: %[[extract:.*]] = vector.extractelement %[[trunc]][%c0 : index] : vector<2xf16>
+  // CHECK: %[[extract:.*]] = vector.extract %[[trunc]][0] : f16 from vector<2xf16>
   // CHECK: return %[[extract]] : f16
   %w = arith.truncf %v : f32 to f16
   return %w : f16
@@ -14,8 +14,8 @@ func.func @scalar_trunc(%v: f32) -> f16{
 // CHECK-LABEL: @vector_trunc
 // CHECK-SAME: (%[[value:.*]]: vector<2xf32>)
 func.func @vector_trunc_short(%v: vector<2xf32>) -> vector<2xf16> {
-  // CHECK: %[[elem0:.*]] = vector.extractelement %[[value]]
-  // CHECK: %[[elem1:.*]] = vector.extractelement %[[value]]
+  // CHECK: %[[elem0:.*]] = vector.extract %[[value]]
+  // CHECK: %[[elem1:.*]] = vector.extract %[[value]]
   // CHECK: %[[ret:.*]] = rocdl.cvt.pkrtz %[[elem0]], %[[elem1]] : vector<2xf16>
   // CHECK: return %[[ret]]
   %w = arith.truncf %v : vector<2xf32> to vector<2xf16>
@@ -25,23 +25,23 @@ func.func @vector_trunc_short(%v: vector<2xf32>) -> vector<2xf16> {
 // CHECK-LABEL:  @vector_trunc_long
 // CHECK-SAME: (%[[value:.*]]: vector<9xf32>)
 func.func @vector_trunc_long(%v: vector<9xf32>) -> vector<9xf16> {
-  // CHECK: %[[elem0:.*]] = vector.extractelement %[[value]][%c0 : index]
-  // CHECK: %[[elem1:.*]] = vector.extractelement %[[value]][%c1 : index]
+  // CHECK: %[[elem0:.*]] = vector.extract %[[value]][0]
+  // CHECK: %[[elem1:.*]] = vector.extract %[[value]][1]
   // CHECK: %[[packed0:.*]] = rocdl.cvt.pkrtz %[[elem0]], %[[elem1]] : vector<2xf16>
   // CHECK: %[[out0:.*]] = vector.insert_strided_slice %[[packed0]], {{.*}} {offsets = [0], strides = [1]} : vector<2xf16> into vector<9xf16>
-  // CHECK: %[[elem2:.*]] = vector.extractelement %[[value]][%c2 : index]
-  // CHECK: %[[elem3:.*]] = vector.extractelement %[[value]][%c3 : index]
+  // CHECK: %[[elem2:.*]] = vector.extract %[[value]][2]
+  // CHECK: %[[elem3:.*]] = vector.extract %[[value]][3]
   // CHECK: %[[packed1:.*]] = rocdl.cvt.pkrtz %[[elem2]], %[[elem3]] : vector<2xf16>
   // CHECK: %[[out1:.*]] = vector.insert_strided_slice %[[packed1]], %[[out0]] {offsets = [2], strides = [1]} : vector<2xf16> into...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Oct 27, 2024

@llvm/pr-subscribers-mlir-gpu

Author: Kunwar Grover (Groverkss)

Changes

This patch removes usages of vector.extractelement/vector.insertelement. These operations can be fully represented by vector.extract/vector.insert. See https://discourse.llvm.org/t/rfc-psa-remove-vector-extractelement-and-vector-insertelement-ops-in-favor-of-vector-extract-and-vector-insert-ops/71116 for more information.


Patch is 73.49 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/113827.diff

20 Files Affected:

  • (modified) mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp (+3-6)
  • (modified) mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp (+9-10)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+5-5)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+15-8)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp (+2-1)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp (+2-3)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp (+2-18)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp (+5-7)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp (+4-14)
  • (modified) mlir/test/Conversion/ArithToAMDGPU/16-bit-floats.mlir (+12-12)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+1-2)
  • (modified) mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir (+7-7)
  • (modified) mlir/test/Dialect/Linalg/vectorization-scalable.mlir (+1-1)
  • (modified) mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir (+4-6)
  • (modified) mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir (+4-8)
  • (modified) mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir (+28-27)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+15-3)
  • (modified) mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir (+31-58)
  • (modified) mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir (+2-4)
  • (modified) mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir (+2-2)
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 6b27ec9947cb0b..6b9cbaf57676c2 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -313,8 +313,7 @@ void TruncfToFloat16RewritePattern::rewrite(arith::TruncFOp op,
     auto sourceB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
     Value asF16s =
         rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, in, sourceB);
-    Value result = rewriter.create<vector::ExtractElementOp>(
-        loc, asF16s, rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0));
+    Value result = rewriter.create<vector::ExtractOp>(loc, asF16s, 0);
     return rewriter.replaceOp(op, result);
   }
   VectorType outType = cast<VectorType>(op.getOut().getType());
@@ -334,13 +333,11 @@ void TruncfToFloat16RewritePattern::rewrite(arith::TruncFOp op,
   for (int64_t i = 0; i < numElements; i += 2) {
     int64_t elemsThisOp = std::min(numElements, i + 2) - i;
     Value thisResult = nullptr;
-    Value elemA = rewriter.create<vector::ExtractElementOp>(
-        loc, in, rewriter.create<arith::ConstantIndexOp>(loc, i));
+    Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i);
     Value elemB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
 
     if (elemsThisOp == 2) {
-      elemB = rewriter.create<vector::ExtractElementOp>(
-          loc, in, rewriter.createOrFold<arith::ConstantIndexOp>(loc, i + 1));
+      elemB = rewriter.create<vector::ExtractOp>(loc, in, i + 1);
     }
 
     thisResult =
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 3a4dc806efe976..ddbc4d2c4a4f3d 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -157,7 +157,7 @@ static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) {
     return Value();
 
   Location loc = xferOp.getLoc();
-  return b.create<vector::ExtractElementOp>(loc, xferOp.getMask(), iv);
+  return b.create<vector::ExtractOp>(loc, xferOp.getMask(), iv);
 }
 
 /// Helper function TransferOpConversion and TransferOp1dConversion.
@@ -686,7 +686,7 @@ struct PrepareTransferWriteConversion
 /// %lastIndex = arith.subi %length, %c1 : index
 /// vector.print punctuation <open>
 /// scf.for %i = %c0 to %length step %c1 {
-///   %el = vector.extractelement %v[%i : index] : vector<[4]xi32>
+///   %el = vector.extract %v[%i : index] : vector<[4]xi32>
 ///   vector.print %el : i32 punctuation <no_punctuation>
 ///   %notLastIndex = arith.cmpi ult, %i, %lastIndex : index
 ///   scf.if %notLastIndex {
@@ -756,7 +756,8 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
     if (vectorType.getRank() != 1) {
       // Flatten n-D vectors to 1D. This is done to allow indexing with a
       // non-constant value (which can currently only be done via
-      // vector.extractelement for 1D vectors).
+      // vector.extract for 1D vectors).
+      // TODO: vector.extract supports N-D non-constant indices now.
       auto flatLength = std::accumulate(shape.begin(), shape.end(), 1,
                                         std::multiplies<int64_t>());
       auto flatVectorType =
@@ -819,8 +820,7 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
     }
 
     // Print the scalar elements in the inner most loop.
-    auto element =
-        rewriter.create<vector::ExtractElementOp>(loc, value, flatIndex);
+    auto element = rewriter.create<vector::ExtractOp>(loc, value, flatIndex);
     rewriter.create<vector::PrintOp>(loc, element,
                                      vector::PrintPunctuation::NoPunctuation);
 
@@ -1563,7 +1563,7 @@ struct Strategy1d<TransferReadOp> {
         [&](OpBuilder &b, Location loc) {
           Value val =
               b.create<memref::LoadOp>(loc, xferOp.getSource(), indices);
-          return b.create<vector::InsertElementOp>(loc, val, vec, iv);
+          return b.create<vector::InsertOp>(loc, val, vec, iv);
         },
         /*outOfBoundsCase=*/
         [&](OpBuilder & /*b*/, Location loc) { return vec; });
@@ -1591,8 +1591,7 @@ struct Strategy1d<TransferWriteOp> {
     generateInBoundsCheck(
         b, xferOp, iv, dim,
         /*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
-          auto val =
-              b.create<vector::ExtractElementOp>(loc, xferOp.getVector(), iv);
+          auto val = b.create<vector::ExtractOp>(loc, xferOp.getVector(), iv);
           b.create<memref::StoreOp>(loc, val, xferOp.getSource(), indices);
         });
     b.create<scf::YieldOp>(loc);
@@ -1614,7 +1613,7 @@ struct Strategy1d<TransferWriteOp> {
 /// This pattern generates IR as follows:
 ///
 /// 1. Generate a for loop iterating over each vector element.
-/// 2. Inside the loop, generate a InsertElementOp or ExtractElementOp,
+/// 2. Inside the loop, generate a InsertOp or ExtractOp,
 ///    depending on OpTy.
 ///
 /// TODO: In some cases (no masking, etc.), LLVM::MatrixColumnMajorLoadOp
@@ -1630,7 +1629,7 @@ struct Strategy1d<TransferWriteOp> {
 /// Is rewritten to approximately the following pseudo-IR:
 /// ```
 /// for i = 0 to 9 {
-///   %t = vector.extractelement %vec[i] : vector<9xf32>
+///   %t = vector.extract %vec[i] : vector<9xf32>
 ///   memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32>
 /// }
 /// ```
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 0a2457176a1d47..e38bbad1637d45 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1134,8 +1134,6 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
   //   * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
   //    (0th) element and use that.
   SmallVector<Value> transferReadIdxs;
-  auto zero = rewriter.create<arith::ConstantOp>(
-      loc, rewriter.getI32Type(), rewriter.getZeroAttr(rewriter.getI32Type()));
   for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
     Value idx = bvm.lookup(extractOp.getIndices()[i]);
     if (idx.getType().isIndex()) {
@@ -1149,7 +1147,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
                         resultType.getScalableDims().back()),
         idx);
     transferReadIdxs.push_back(
-        rewriter.create<vector::ExtractElementOp>(loc, indexAs1dVector, zero));
+        rewriter.create<vector::ExtractOp>(loc, indexAs1dVector, 0));
   }
 
   // `tensor.extract_element` is always in-bounds, hence the following holds.
@@ -1415,7 +1413,8 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
     // 3.c. Not all ops support 0-d vectors, extract the scalar for now.
     // TODO: remove this.
     if (readType.getRank() == 0)
-      readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
+      readValue = rewriter.create<vector::ExtractOp>(loc, readValue,
+                                                     SmallVector<int64_t>{});
 
     LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue
                                  << "\n");
@@ -2268,7 +2267,8 @@ LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
       loc, readType, copyOp.getSource(), indices,
       rewriter.getMultiDimIdentityMap(srcType.getRank()));
   if (cast<VectorType>(readValue.getType()).getRank() == 0) {
-    readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
+    readValue = rewriter.create<vector::ExtractOp>(loc, readValue,
+                                                   SmallVector<int64_t>{});
     readValue = rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
   }
   Operation *writeValue = rewriter.create<vector::TransferWriteOp>(
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d71a236f62f454..af5b3637bf5b10 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -697,8 +697,9 @@ struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
     Value result;
     if (vectorType.getRank() == 0) {
       if (mask)
-        mask = rewriter.create<ExtractElementOp>(loc, mask);
-      result = rewriter.create<ExtractElementOp>(loc, reductionOp.getVector());
+        mask = rewriter.create<ExtractOp>(loc, mask, SmallVector<int64_t>{});
+      result = rewriter.create<ExtractOp>(loc, reductionOp.getVector(),
+                                          SmallVector<int64_t>{});
     } else {
       if (mask)
         mask = rewriter.create<ExtractOp>(loc, mask, 0);
@@ -1983,12 +1984,18 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
     if (extractResultRank < broadcastSrcRank)
       return failure();
 
-    // Special case if broadcast src is a 0D vector.
+    // If extractResultRank is 0, broadcastSrcRank has to be zero, since
+    // broadcastSrcRank >= extractResultRank for this pattern. If so, the input
+    // to the broadcast will be a vector<f32> or f32, but the result will be a
+    // f32, because of vector.extract 0-d semantics. Therefore, we instead
+    // just replace the broadcast with a vector.extract.
     if (extractResultRank == 0) {
       assert(broadcastSrcRank == 0 && llvm::isa<VectorType>(source.getType()));
-      rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(extractOp, source);
+      rewriter.replaceOpWithNewOp<vector::ExtractOp>(extractOp, source,
+                                                     SmallVector<int64_t>{});
       return success();
     }
+
     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
         extractOp, extractOp.getType(), source);
     return success();
@@ -2951,11 +2958,11 @@ void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
               InsertOpConstantFolder>(context);
 }
 
-// Eliminates insert operations that produce values identical to their source
-// value. This happens when the source and destination vectors have identical
-// sizes.
 OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
-  if (getNumIndices() == 0)
+  // Fold "vector.insert %v, %dest [] : vector<2x2xf32> from vector<2x2xf32>" to
+  // %v. Note: Do not fold "vector.insert %v, %dest [] : f32 into vector<f32>"
+  // (type mismatch).
+  if (getNumIndices() == 0 && getSourceType() == getResult().getType())
     return getSource();
   return {};
 }
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
index 6c36bbaee85237..6d82d753eeed80 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
@@ -65,7 +65,8 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
     if (srcRank <= 1 && dstRank == 1) {
       Value ext;
       if (srcRank == 0)
-        ext = rewriter.create<vector::ExtractElementOp>(loc, op.getSource());
+        ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(),
+                                                 SmallVector<int64_t>{});
       else
         ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);
       rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 716da55ba09aec..72bf329daaa76e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -391,9 +391,8 @@ struct TwoDimMultiReductionToReduction
         reductionOp = mlir::vector::maskOperation(rewriter, reductionOp, mask);
       }
 
-      result = rewriter.create<vector::InsertElementOp>(
-          loc, reductionOp->getResult(0), result,
-          rewriter.create<arith::ConstantIndexOp>(loc, i));
+      result = rewriter.create<vector::InsertOp>(loc, reductionOp->getResult(0),
+                                                 result, i);
     }
 
     rewriter.replaceOp(rootOp, result);
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
index 95ebd4e9fe3d99..343178c8156d25 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
@@ -177,24 +177,8 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
       }
 
       Value extract;
-      if (srcRank == 0) {
-        // 0-D vector special case
-        assert(srcIdx.empty() && "Unexpected indices for 0-D vector");
-        extract = rewriter.create<vector::ExtractElementOp>(
-            loc, op.getSourceVectorType().getElementType(), op.getSource());
-      } else {
-        extract =
-            rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
-      }
-
-      if (resRank == 0) {
-        // 0-D vector special case
-        assert(resIdx.empty() && "Unexpected indices for 0-D vector");
-        result = rewriter.create<vector::InsertElementOp>(loc, extract, result);
-      } else {
-        result =
-            rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
-      }
+      extract = rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
+      result = rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
     }
     rewriter.replaceOp(op, result);
     return success();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 2289fd1ff1364e..4ea6bcf3181dae 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1238,7 +1238,7 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
     if (extractOp.getNumIndices() == 0)
       return failure();
 
-    // Rewrite vector.extract with 1d source to vector.extractelement.
+    // Rewrite vector.extract with 1d source to vector.extract.
     if (extractSrcType.getRank() == 1) {
       if (extractOp.hasDynamicPosition())
         // TODO: Dinamic position not supported yet.
@@ -1247,9 +1247,8 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
       assert(extractOp.getNumIndices() == 1 && "expected 1 index");
       int64_t pos = extractOp.getStaticPosition()[0];
       rewriter.setInsertionPoint(extractOp);
-      rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(
-          extractOp, extractOp.getVector(),
-          rewriter.create<arith::ConstantIndexOp>(loc, pos));
+      rewriter.replaceOpWithNewOp<vector::ExtractOp>(
+          extractOp, extractOp.getVector(), pos);
       return success();
     }
 
@@ -1519,9 +1518,8 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
       assert(insertOp.getNumIndices() == 1 && "expected 1 index");
       int64_t pos = insertOp.getStaticPosition()[0];
       rewriter.setInsertionPoint(insertOp);
-      rewriter.replaceOpWithNewOp<vector::InsertElementOp>(
-          insertOp, insertOp.getSource(), insertOp.getDest(),
-          rewriter.create<arith::ConstantIndexOp>(loc, pos));
+      rewriter.replaceOpWithNewOp<vector::InsertOp>(
+          insertOp, insertOp.getSource(), insertOp.getDest(), pos);
       return success();
     }
 
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
index ec2ef3fc7501c2..a5d5dc00b33cd3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -21,23 +21,13 @@ using namespace mlir::vector;
 // Helper that picks the proper sequence for inserting.
 static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
                        Value into, int64_t offset) {
-  auto vectorType = cast<VectorType>(into.getType());
-  if (vectorType.getRank() > 1)
-    return rewriter.create<InsertOp>(loc, from, into, offset);
-  return rewriter.create<vector::InsertElementOp>(
-      loc, vectorType, from, into,
-      rewriter.create<arith::ConstantIndexOp>(loc, offset));
+  return rewriter.create<InsertOp>(loc, from, into, offset);
 }
 
 // Helper that picks the proper sequence for extracting.
 static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
                         int64_t offset) {
-  auto vectorType = cast<VectorType>(vector.getType());
-  if (vectorType.getRank() > 1)
-    return rewriter.create<ExtractOp>(loc, vector, offset);
-  return rewriter.create<vector::ExtractElementOp>(
-      loc, vectorType.getElementType(), vector,
-      rewriter.create<arith::ConstantIndexOp>(loc, offset));
+  return rewriter.create<ExtractOp>(loc, vector, offset);
 }
 
 /// RewritePattern for InsertStridedSliceOp where source and destination vectors
@@ -277,8 +267,8 @@ class Convert1DExtractStridedSliceIntoExtractInsertChain final
 };
 
 /// RewritePattern for ExtractStridedSliceOp where the source vector is n-D.
-/// For such cases, we can rewrite it to ExtractOp/ExtractElementOp + lower
-/// rank ExtractStridedSliceOp + InsertOp/InsertElementOp for the n-D case.
+/// For such cases, we can rewrite it to ExtractOp + lower rank
+/// ExtractStridedSliceOp + InsertOp for the n-D case.
 class DecomposeNDExtractStridedSlice
     : public OpRewritePattern<ExtractStridedSliceOp> {
 public:
diff --git a/mlir/test/Conversion/ArithToAMDGPU/16-bit-floats.mlir b/mlir/test/Conversion/ArithToAMDGPU/16-bit-floats.mlir
index 121cae26748a82..8991506dee1dfb 100644
--- a/mlir/test/Conversion/ArithToAMDGPU/16-bit-floats.mlir
+++ b/mlir/test/Conversion/ArithToAMDGPU/16-bit-floats.mlir
@@ -5,7 +5,7 @@
 func.func @scalar_trunc(%v: f32) -> f16{
   // CHECK: %[[poison:.*]] = llvm.mlir.poison : f32
   // CHECK: %[[trunc:.*]] = rocdl.cvt.pkrtz %[[value]], %[[poison]] : vector<2xf16>
-  // CHECK: %[[extract:.*]] = vector.extractelement %[[trunc]][%c0 : index] : vector<2xf16>
+  // CHECK: %[[extract:.*]] = vector.extract %[[trunc]][0] : f16 from vector<2xf16>
   // CHECK: return %[[extract]] : f16
   %w = arith.truncf %v : f32 to f16
   return %w : f16
@@ -14,8 +14,8 @@ func.func @scalar_trunc(%v: f32) -> f16{
 // CHECK-LABEL: @vector_trunc
 // CHECK-SAME: (%[[value:.*]]: vector<2xf32>)
 func.func @vector_trunc_short(%v: vector<2xf32>) -> vector<2xf16> {
-  // CHECK: %[[elem0:.*]] = vector.extractelement %[[value]]
-  // CHECK: %[[elem1:.*]] = vector.extractelement %[[value]]
+  // CHECK: %[[elem0:.*]] = vector.extract %[[value]]
+  // CHECK: %[[elem1:.*]] = vector.extract %[[value]]
   // CHECK: %[[ret:.*]] = rocdl.cvt.pkrtz %[[elem0]], %[[elem1]] : vector<2xf16>
   // CHECK: return %[[ret]]
   %w = arith.truncf %v : vector<2xf32> to vector<2xf16>
@@ -25,23 +25,23 @@ func.func @vector_trunc_short(%v: vector<2xf32>) -> vector<2xf16> {
 // CHECK-LABEL:  @vector_trunc_long
 // CHECK-SAME: (%[[value:.*]]: vector<9xf32>)
 func.func @vector_trunc_long(%v: vector<9xf32>) -> vector<9xf16> {
-  // CHECK: %[[elem0:.*]] = vector.extractelement %[[value]][%c0 : index]
-  // CHECK: %[[elem1:.*]] = vector.extractelement %[[value]][%c1 : index]
+  // CHECK: %[[elem0:.*]] = vector.extract %[[value]][0]
+  // CHECK: %[[elem1:.*]] = vector.extract %[[value]][1]
   // CHECK: %[[packed0:.*]] = rocdl.cvt.pkrtz %[[elem0]], %[[elem1]] : vector<2xf16>
   // CHECK: %[[out0:.*]] = vector.insert_strided_slice %[[packed0]], {{.*}} {offsets = [0], strides = [1]} : vector<2xf16> into vector<9xf16>
-  // CHECK: %[[elem2:.*]] = vector.extractelement %[[value]][%c2 : index]
-  // CHECK: %[[elem3:.*]] = vector.extractelement %[[value]][%c3 : index]
+  // CHECK: %[[elem2:.*]] = vector.extract %[[value]][2]
+  // CHECK: %[[elem3:.*]] = vector.extract %[[value]][3]
   // CHECK: %[[packed1:.*]] = rocdl.cvt.pkrtz %[[elem2]], %[[elem3]] : vector<2xf16>
   // CHECK: %[[out1:.*]] = vector.insert_strided_slice %[[packed1]], %[[out0]] {offsets = [2], strides = [1]} : vector<2xf16> into...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Oct 27, 2024

@llvm/pr-subscribers-mlir-linalg

Author: Kunwar Grover (Groverkss)

Changes

This patch removes usages of vector.extractelement/vector.insertelement. These operations can be fully represented by vector.extract/vector.insert. See https://discourse.llvm.org/t/rfc-psa-remove-vector-extractelement-and-vector-insertelement-ops-in-favor-of-vector-extract-and-vector-insert-ops/71116 for more information.


Patch is 73.49 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/113827.diff

20 Files Affected:

  • (modified) mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp (+3-6)
  • (modified) mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp (+9-10)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+5-5)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+15-8)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp (+2-1)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp (+2-3)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp (+2-18)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp (+5-7)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp (+4-14)
  • (modified) mlir/test/Conversion/ArithToAMDGPU/16-bit-floats.mlir (+12-12)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+1-2)
  • (modified) mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir (+7-7)
  • (modified) mlir/test/Dialect/Linalg/vectorization-scalable.mlir (+1-1)
  • (modified) mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir (+4-6)
  • (modified) mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir (+4-8)
  • (modified) mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir (+28-27)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+15-3)
  • (modified) mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir (+31-58)
  • (modified) mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir (+2-4)
  • (modified) mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir (+2-2)
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 6b27ec9947cb0b..6b9cbaf57676c2 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -313,8 +313,7 @@ void TruncfToFloat16RewritePattern::rewrite(arith::TruncFOp op,
     auto sourceB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
     Value asF16s =
         rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, in, sourceB);
-    Value result = rewriter.create<vector::ExtractElementOp>(
-        loc, asF16s, rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0));
+    Value result = rewriter.create<vector::ExtractOp>(loc, asF16s, 0);
     return rewriter.replaceOp(op, result);
   }
   VectorType outType = cast<VectorType>(op.getOut().getType());
@@ -334,13 +333,11 @@ void TruncfToFloat16RewritePattern::rewrite(arith::TruncFOp op,
   for (int64_t i = 0; i < numElements; i += 2) {
     int64_t elemsThisOp = std::min(numElements, i + 2) - i;
     Value thisResult = nullptr;
-    Value elemA = rewriter.create<vector::ExtractElementOp>(
-        loc, in, rewriter.create<arith::ConstantIndexOp>(loc, i));
+    Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i);
     Value elemB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
 
     if (elemsThisOp == 2) {
-      elemB = rewriter.create<vector::ExtractElementOp>(
-          loc, in, rewriter.createOrFold<arith::ConstantIndexOp>(loc, i + 1));
+      elemB = rewriter.create<vector::ExtractOp>(loc, in, i + 1);
     }
 
     thisResult =
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 3a4dc806efe976..ddbc4d2c4a4f3d 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -157,7 +157,7 @@ static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) {
     return Value();
 
   Location loc = xferOp.getLoc();
-  return b.create<vector::ExtractElementOp>(loc, xferOp.getMask(), iv);
+  return b.create<vector::ExtractOp>(loc, xferOp.getMask(), iv);
 }
 
 /// Helper function TransferOpConversion and TransferOp1dConversion.
@@ -686,7 +686,7 @@ struct PrepareTransferWriteConversion
 /// %lastIndex = arith.subi %length, %c1 : index
 /// vector.print punctuation <open>
 /// scf.for %i = %c0 to %length step %c1 {
-///   %el = vector.extractelement %v[%i : index] : vector<[4]xi32>
+///   %el = vector.extract %v[%i : index] : vector<[4]xi32>
 ///   vector.print %el : i32 punctuation <no_punctuation>
 ///   %notLastIndex = arith.cmpi ult, %i, %lastIndex : index
 ///   scf.if %notLastIndex {
@@ -756,7 +756,8 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
     if (vectorType.getRank() != 1) {
       // Flatten n-D vectors to 1D. This is done to allow indexing with a
       // non-constant value (which can currently only be done via
-      // vector.extractelement for 1D vectors).
+      // vector.extract for 1D vectors).
+      // TODO: vector.extract supports N-D non-constant indices now.
       auto flatLength = std::accumulate(shape.begin(), shape.end(), 1,
                                         std::multiplies<int64_t>());
       auto flatVectorType =
@@ -819,8 +820,7 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
     }
 
     // Print the scalar elements in the inner most loop.
-    auto element =
-        rewriter.create<vector::ExtractElementOp>(loc, value, flatIndex);
+    auto element = rewriter.create<vector::ExtractOp>(loc, value, flatIndex);
     rewriter.create<vector::PrintOp>(loc, element,
                                      vector::PrintPunctuation::NoPunctuation);
 
@@ -1563,7 +1563,7 @@ struct Strategy1d<TransferReadOp> {
         [&](OpBuilder &b, Location loc) {
           Value val =
               b.create<memref::LoadOp>(loc, xferOp.getSource(), indices);
-          return b.create<vector::InsertElementOp>(loc, val, vec, iv);
+          return b.create<vector::InsertOp>(loc, val, vec, iv);
         },
         /*outOfBoundsCase=*/
         [&](OpBuilder & /*b*/, Location loc) { return vec; });
@@ -1591,8 +1591,7 @@ struct Strategy1d<TransferWriteOp> {
     generateInBoundsCheck(
         b, xferOp, iv, dim,
         /*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
-          auto val =
-              b.create<vector::ExtractElementOp>(loc, xferOp.getVector(), iv);
+          auto val = b.create<vector::ExtractOp>(loc, xferOp.getVector(), iv);
           b.create<memref::StoreOp>(loc, val, xferOp.getSource(), indices);
         });
     b.create<scf::YieldOp>(loc);
@@ -1614,7 +1613,7 @@ struct Strategy1d<TransferWriteOp> {
 /// This pattern generates IR as follows:
 ///
 /// 1. Generate a for loop iterating over each vector element.
-/// 2. Inside the loop, generate a InsertElementOp or ExtractElementOp,
+/// 2. Inside the loop, generate a InsertOp or ExtractOp,
 ///    depending on OpTy.
 ///
 /// TODO: In some cases (no masking, etc.), LLVM::MatrixColumnMajorLoadOp
@@ -1630,7 +1629,7 @@ struct Strategy1d<TransferWriteOp> {
 /// Is rewritten to approximately the following pseudo-IR:
 /// ```
 /// for i = 0 to 9 {
-///   %t = vector.extractelement %vec[i] : vector<9xf32>
+///   %t = vector.extract %vec[i] : vector<9xf32>
 ///   memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32>
 /// }
 /// ```
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 0a2457176a1d47..e38bbad1637d45 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1134,8 +1134,6 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
   //   * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
   //    (0th) element and use that.
   SmallVector<Value> transferReadIdxs;
-  auto zero = rewriter.create<arith::ConstantOp>(
-      loc, rewriter.getI32Type(), rewriter.getZeroAttr(rewriter.getI32Type()));
   for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
     Value idx = bvm.lookup(extractOp.getIndices()[i]);
     if (idx.getType().isIndex()) {
@@ -1149,7 +1147,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
                         resultType.getScalableDims().back()),
         idx);
     transferReadIdxs.push_back(
-        rewriter.create<vector::ExtractElementOp>(loc, indexAs1dVector, zero));
+        rewriter.create<vector::ExtractOp>(loc, indexAs1dVector, 0));
   }
 
   // `tensor.extract_element` is always in-bounds, hence the following holds.
@@ -1415,7 +1413,8 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
     // 3.c. Not all ops support 0-d vectors, extract the scalar for now.
     // TODO: remove this.
     if (readType.getRank() == 0)
-      readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
+      readValue = rewriter.create<vector::ExtractOp>(loc, readValue,
+                                                     SmallVector<int64_t>{});
 
     LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue
                                  << "\n");
@@ -2268,7 +2267,8 @@ LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
       loc, readType, copyOp.getSource(), indices,
       rewriter.getMultiDimIdentityMap(srcType.getRank()));
   if (cast<VectorType>(readValue.getType()).getRank() == 0) {
-    readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
+    readValue = rewriter.create<vector::ExtractOp>(loc, readValue,
+                                                   SmallVector<int64_t>{});
     readValue = rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
   }
   Operation *writeValue = rewriter.create<vector::TransferWriteOp>(
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d71a236f62f454..af5b3637bf5b10 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -697,8 +697,9 @@ struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
     Value result;
     if (vectorType.getRank() == 0) {
       if (mask)
-        mask = rewriter.create<ExtractElementOp>(loc, mask);
-      result = rewriter.create<ExtractElementOp>(loc, reductionOp.getVector());
+        mask = rewriter.create<ExtractOp>(loc, mask, SmallVector<int64_t>{});
+      result = rewriter.create<ExtractOp>(loc, reductionOp.getVector(),
+                                          SmallVector<int64_t>{});
     } else {
       if (mask)
         mask = rewriter.create<ExtractOp>(loc, mask, 0);
@@ -1983,12 +1984,18 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
     if (extractResultRank < broadcastSrcRank)
       return failure();
 
-    // Special case if broadcast src is a 0D vector.
+    // If extractResultRank is 0, broadcastSrcRank has to be zero, since
+    // broadcastSrcRank >= extractResultRank for this pattern. If so, the input
+    // to the broadcast will be a vector<f32> or f32, but the result will be a
+    // f32, because of vector.extract 0-d semantics. Therefore, we instead
+    // just replace the broadcast with a vector.extract.
     if (extractResultRank == 0) {
       assert(broadcastSrcRank == 0 && llvm::isa<VectorType>(source.getType()));
-      rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(extractOp, source);
+      rewriter.replaceOpWithNewOp<vector::ExtractOp>(extractOp, source,
+                                                     SmallVector<int64_t>{});
       return success();
     }
+
     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
         extractOp, extractOp.getType(), source);
     return success();
@@ -2951,11 +2958,11 @@ void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
               InsertOpConstantFolder>(context);
 }
 
-// Eliminates insert operations that produce values identical to their source
-// value. This happens when the source and destination vectors have identical
-// sizes.
 OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
-  if (getNumIndices() == 0)
+  // Fold "vector.insert %v, %dest [] : vector<2x2xf32> from vector<2x2xf32>" to
+  // %v. Note: Do not fold "vector.insert %v, %dest [] : f32 into vector<f32>"
+  // (type mismatch).
+  if (getNumIndices() == 0 && getSourceType() == getResult().getType())
     return getSource();
   return {};
 }
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
index 6c36bbaee85237..6d82d753eeed80 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
@@ -65,7 +65,8 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
     if (srcRank <= 1 && dstRank == 1) {
       Value ext;
       if (srcRank == 0)
-        ext = rewriter.create<vector::ExtractElementOp>(loc, op.getSource());
+        ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(),
+                                                 SmallVector<int64_t>{});
       else
         ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);
       rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 716da55ba09aec..72bf329daaa76e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -391,9 +391,8 @@ struct TwoDimMultiReductionToReduction
         reductionOp = mlir::vector::maskOperation(rewriter, reductionOp, mask);
       }
 
-      result = rewriter.create<vector::InsertElementOp>(
-          loc, reductionOp->getResult(0), result,
-          rewriter.create<arith::ConstantIndexOp>(loc, i));
+      result = rewriter.create<vector::InsertOp>(loc, reductionOp->getResult(0),
+                                                 result, i);
     }
 
     rewriter.replaceOp(rootOp, result);
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
index 95ebd4e9fe3d99..343178c8156d25 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
@@ -177,24 +177,8 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
       }
 
       Value extract;
-      if (srcRank == 0) {
-        // 0-D vector special case
-        assert(srcIdx.empty() && "Unexpected indices for 0-D vector");
-        extract = rewriter.create<vector::ExtractElementOp>(
-            loc, op.getSourceVectorType().getElementType(), op.getSource());
-      } else {
-        extract =
-            rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
-      }
-
-      if (resRank == 0) {
-        // 0-D vector special case
-        assert(resIdx.empty() && "Unexpected indices for 0-D vector");
-        result = rewriter.create<vector::InsertElementOp>(loc, extract, result);
-      } else {
-        result =
-            rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
-      }
+      extract = rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
+      result = rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
     }
     rewriter.replaceOp(op, result);
     return success();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 2289fd1ff1364e..4ea6bcf3181dae 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1238,7 +1238,7 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
     if (extractOp.getNumIndices() == 0)
       return failure();
 
-    // Rewrite vector.extract with 1d source to vector.extractelement.
+    // Rewrite vector.extract with 1d source to vector.extract.
     if (extractSrcType.getRank() == 1) {
       if (extractOp.hasDynamicPosition())
         // TODO: Dinamic position not supported yet.
@@ -1247,9 +1247,8 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
       assert(extractOp.getNumIndices() == 1 && "expected 1 index");
       int64_t pos = extractOp.getStaticPosition()[0];
       rewriter.setInsertionPoint(extractOp);
-      rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(
-          extractOp, extractOp.getVector(),
-          rewriter.create<arith::ConstantIndexOp>(loc, pos));
+      rewriter.replaceOpWithNewOp<vector::ExtractOp>(
+          extractOp, extractOp.getVector(), pos);
       return success();
     }
 
@@ -1519,9 +1518,8 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
       assert(insertOp.getNumIndices() == 1 && "expected 1 index");
       int64_t pos = insertOp.getStaticPosition()[0];
       rewriter.setInsertionPoint(insertOp);
-      rewriter.replaceOpWithNewOp<vector::InsertElementOp>(
-          insertOp, insertOp.getSource(), insertOp.getDest(),
-          rewriter.create<arith::ConstantIndexOp>(loc, pos));
+      rewriter.replaceOpWithNewOp<vector::InsertOp>(
+          insertOp, insertOp.getSource(), insertOp.getDest(), pos);
       return success();
     }
 
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
index ec2ef3fc7501c2..a5d5dc00b33cd3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -21,23 +21,13 @@ using namespace mlir::vector;
 // Helper that picks the proper sequence for inserting.
 static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
                        Value into, int64_t offset) {
-  auto vectorType = cast<VectorType>(into.getType());
-  if (vectorType.getRank() > 1)
-    return rewriter.create<InsertOp>(loc, from, into, offset);
-  return rewriter.create<vector::InsertElementOp>(
-      loc, vectorType, from, into,
-      rewriter.create<arith::ConstantIndexOp>(loc, offset));
+  return rewriter.create<InsertOp>(loc, from, into, offset);
 }
 
 // Helper that picks the proper sequence for extracting.
 static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
                         int64_t offset) {
-  auto vectorType = cast<VectorType>(vector.getType());
-  if (vectorType.getRank() > 1)
-    return rewriter.create<ExtractOp>(loc, vector, offset);
-  return rewriter.create<vector::ExtractElementOp>(
-      loc, vectorType.getElementType(), vector,
-      rewriter.create<arith::ConstantIndexOp>(loc, offset));
+  return rewriter.create<ExtractOp>(loc, vector, offset);
 }
 
 /// RewritePattern for InsertStridedSliceOp where source and destination vectors
@@ -277,8 +267,8 @@ class Convert1DExtractStridedSliceIntoExtractInsertChain final
 };
 
 /// RewritePattern for ExtractStridedSliceOp where the source vector is n-D.
-/// For such cases, we can rewrite it to ExtractOp/ExtractElementOp + lower
-/// rank ExtractStridedSliceOp + InsertOp/InsertElementOp for the n-D case.
+/// For such cases, we can rewrite it to ExtractOp + lower rank
+/// ExtractStridedSliceOp + InsertOp for the n-D case.
 class DecomposeNDExtractStridedSlice
     : public OpRewritePattern<ExtractStridedSliceOp> {
 public:
diff --git a/mlir/test/Conversion/ArithToAMDGPU/16-bit-floats.mlir b/mlir/test/Conversion/ArithToAMDGPU/16-bit-floats.mlir
index 121cae26748a82..8991506dee1dfb 100644
--- a/mlir/test/Conversion/ArithToAMDGPU/16-bit-floats.mlir
+++ b/mlir/test/Conversion/ArithToAMDGPU/16-bit-floats.mlir
@@ -5,7 +5,7 @@
 func.func @scalar_trunc(%v: f32) -> f16{
   // CHECK: %[[poison:.*]] = llvm.mlir.poison : f32
   // CHECK: %[[trunc:.*]] = rocdl.cvt.pkrtz %[[value]], %[[poison]] : vector<2xf16>
-  // CHECK: %[[extract:.*]] = vector.extractelement %[[trunc]][%c0 : index] : vector<2xf16>
+  // CHECK: %[[extract:.*]] = vector.extract %[[trunc]][0] : f16 from vector<2xf16>
   // CHECK: return %[[extract]] : f16
   %w = arith.truncf %v : f32 to f16
   return %w : f16
@@ -14,8 +14,8 @@ func.func @scalar_trunc(%v: f32) -> f16{
 // CHECK-LABEL: @vector_trunc
 // CHECK-SAME: (%[[value:.*]]: vector<2xf32>)
 func.func @vector_trunc_short(%v: vector<2xf32>) -> vector<2xf16> {
-  // CHECK: %[[elem0:.*]] = vector.extractelement %[[value]]
-  // CHECK: %[[elem1:.*]] = vector.extractelement %[[value]]
+  // CHECK: %[[elem0:.*]] = vector.extract %[[value]]
+  // CHECK: %[[elem1:.*]] = vector.extract %[[value]]
   // CHECK: %[[ret:.*]] = rocdl.cvt.pkrtz %[[elem0]], %[[elem1]] : vector<2xf16>
   // CHECK: return %[[ret]]
   %w = arith.truncf %v : vector<2xf32> to vector<2xf16>
@@ -25,23 +25,23 @@ func.func @vector_trunc_short(%v: vector<2xf32>) -> vector<2xf16> {
 // CHECK-LABEL:  @vector_trunc_long
 // CHECK-SAME: (%[[value:.*]]: vector<9xf32>)
 func.func @vector_trunc_long(%v: vector<9xf32>) -> vector<9xf16> {
-  // CHECK: %[[elem0:.*]] = vector.extractelement %[[value]][%c0 : index]
-  // CHECK: %[[elem1:.*]] = vector.extractelement %[[value]][%c1 : index]
+  // CHECK: %[[elem0:.*]] = vector.extract %[[value]][0]
+  // CHECK: %[[elem1:.*]] = vector.extract %[[value]][1]
   // CHECK: %[[packed0:.*]] = rocdl.cvt.pkrtz %[[elem0]], %[[elem1]] : vector<2xf16>
   // CHECK: %[[out0:.*]] = vector.insert_strided_slice %[[packed0]], {{.*}} {offsets = [0], strides = [1]} : vector<2xf16> into vector<9xf16>
-  // CHECK: %[[elem2:.*]] = vector.extractelement %[[value]][%c2 : index]
-  // CHECK: %[[elem3:.*]] = vector.extractelement %[[value]][%c3 : index]
+  // CHECK: %[[elem2:.*]] = vector.extract %[[value]][2]
+  // CHECK: %[[elem3:.*]] = vector.extract %[[value]][3]
   // CHECK: %[[packed1:.*]] = rocdl.cvt.pkrtz %[[elem2]], %[[elem3]] : vector<2xf16>
   // CHECK: %[[out1:.*]] = vector.insert_strided_slice %[[packed1]], %[[out0]] {offsets = [2], strides = [1]} : vector<2xf16> into...
[truncated]

@Groverkss
Copy link
Member Author

Depends on #113828

Please only review the top commit until that patch lands

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you so much for the clean-up!

There's a lot of subtle changes here. Would you mind splitting into smaller PRs? For example:

  • vectorization
  • VectorOps
  • etc

?

@@ -756,7 +756,8 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
if (vectorType.getRank() != 1) {
// Flatten n-D vectors to 1D. This is done to allow indexing with a
// non-constant value (which can currently only be done via
// vector.extractelement for 1D vectors).
// vector.extract for 1D vectors).
// TODO: vector.extract supports N-D non-constant indices now.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's either remove this block (preferred) or rephrase the outdated comment - right now it's not clear what the TODO is (missing verb).

Suggestion (but I'd really prefer this being removed instead):

      // Flatten n-D vectors to 1D. This is done for historical reasons and should be removed (using dynamic indices to extract n-D vectors was not supported when this was added).
      // TODO: Remove this block

Or something similar :)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fully dynamic n-D vector.inserts/extracts still do not have a lowering.

@@ -233,10 +233,9 @@ func.func @broadcast_vec2d_from_vec0d(%arg0: vector<f32>) -> vector<3x2xf32> {
// CHECK-LABEL: @broadcast_vec2d_from_vec0d(
// CHECK-SAME: %[[A:.*]]: vector<f32>)
// CHECK: %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<f32> to vector<1xf32>
// CHECK: %[[T5:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<1xf32> to f32
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, why wouldn't this be lowered to llvm.extractelement? And is T5 used at all? In fact, would you mind sharing the output?

Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you so much! LGTM % comments from other reviewers

@@ -1614,7 +1613,7 @@ struct Strategy1d<TransferWriteOp> {
/// This pattern generates IR as follows:
///
/// 1. Generate a for loop iterating over each vector element.
/// 2. Inside the loop, generate a InsertElementOp or ExtractElementOp,
/// 2. Inside the loop, generate a InsertOp or ExtractOp,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: an

@Groverkss
Copy link
Member Author

Thank you so much for the clean-up!

There's a lot of subtle changes here. Would you mind splitting into smaller PRs? For example:

  • vectorization
  • VectorOps
  • etc

?

I'm not sure if i can split the entire patch into a pr for each file, but i can help splitting it up into easier review pieces. I sent out the first split for trivial changes: #116053

@Groverkss
Copy link
Member Author

Splitting this patch into smaller pieces

@Groverkss Groverkss closed this Nov 13, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants