-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][Vector] Remove uses of vector.extractelement/vector.insertelement #113827
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-backend-amdgpu Author: Kunwar Grover (Groverkss) ChangesThis patch removes usages of vector.extractelement/vector.insertelement. These operations can be fully represented by vector.extract/vector.insert. See https://discourse.llvm.org/t/rfc-psa-remove-vector-extractelement-and-vector-insertelement-ops-in-favor-of-vector-extract-and-vector-insert-ops/71116 for more information. Patch is 73.49 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/113827.diff 20 Files Affected:
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 6b27ec9947cb0b..6b9cbaf57676c2 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -313,8 +313,7 @@ void TruncfToFloat16RewritePattern::rewrite(arith::TruncFOp op,
auto sourceB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
Value asF16s =
rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, in, sourceB);
- Value result = rewriter.create<vector::ExtractElementOp>(
- loc, asF16s, rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0));
+ Value result = rewriter.create<vector::ExtractOp>(loc, asF16s, 0);
return rewriter.replaceOp(op, result);
}
VectorType outType = cast<VectorType>(op.getOut().getType());
@@ -334,13 +333,11 @@ void TruncfToFloat16RewritePattern::rewrite(arith::TruncFOp op,
for (int64_t i = 0; i < numElements; i += 2) {
int64_t elemsThisOp = std::min(numElements, i + 2) - i;
Value thisResult = nullptr;
- Value elemA = rewriter.create<vector::ExtractElementOp>(
- loc, in, rewriter.create<arith::ConstantIndexOp>(loc, i));
+ Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i);
Value elemB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
if (elemsThisOp == 2) {
- elemB = rewriter.create<vector::ExtractElementOp>(
- loc, in, rewriter.createOrFold<arith::ConstantIndexOp>(loc, i + 1));
+ elemB = rewriter.create<vector::ExtractOp>(loc, in, i + 1);
}
thisResult =
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 3a4dc806efe976..ddbc4d2c4a4f3d 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -157,7 +157,7 @@ static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) {
return Value();
Location loc = xferOp.getLoc();
- return b.create<vector::ExtractElementOp>(loc, xferOp.getMask(), iv);
+ return b.create<vector::ExtractOp>(loc, xferOp.getMask(), iv);
}
/// Helper function TransferOpConversion and TransferOp1dConversion.
@@ -686,7 +686,7 @@ struct PrepareTransferWriteConversion
/// %lastIndex = arith.subi %length, %c1 : index
/// vector.print punctuation <open>
/// scf.for %i = %c0 to %length step %c1 {
-/// %el = vector.extractelement %v[%i : index] : vector<[4]xi32>
+/// %el = vector.extract %v[%i : index] : vector<[4]xi32>
/// vector.print %el : i32 punctuation <no_punctuation>
/// %notLastIndex = arith.cmpi ult, %i, %lastIndex : index
/// scf.if %notLastIndex {
@@ -756,7 +756,8 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
if (vectorType.getRank() != 1) {
// Flatten n-D vectors to 1D. This is done to allow indexing with a
// non-constant value (which can currently only be done via
- // vector.extractelement for 1D vectors).
+ // vector.extract for 1D vectors).
+ // TODO: vector.extract supports N-D non-constant indices now.
auto flatLength = std::accumulate(shape.begin(), shape.end(), 1,
std::multiplies<int64_t>());
auto flatVectorType =
@@ -819,8 +820,7 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
}
// Print the scalar elements in the inner most loop.
- auto element =
- rewriter.create<vector::ExtractElementOp>(loc, value, flatIndex);
+ auto element = rewriter.create<vector::ExtractOp>(loc, value, flatIndex);
rewriter.create<vector::PrintOp>(loc, element,
vector::PrintPunctuation::NoPunctuation);
@@ -1563,7 +1563,7 @@ struct Strategy1d<TransferReadOp> {
[&](OpBuilder &b, Location loc) {
Value val =
b.create<memref::LoadOp>(loc, xferOp.getSource(), indices);
- return b.create<vector::InsertElementOp>(loc, val, vec, iv);
+ return b.create<vector::InsertOp>(loc, val, vec, iv);
},
/*outOfBoundsCase=*/
[&](OpBuilder & /*b*/, Location loc) { return vec; });
@@ -1591,8 +1591,7 @@ struct Strategy1d<TransferWriteOp> {
generateInBoundsCheck(
b, xferOp, iv, dim,
/*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
- auto val =
- b.create<vector::ExtractElementOp>(loc, xferOp.getVector(), iv);
+ auto val = b.create<vector::ExtractOp>(loc, xferOp.getVector(), iv);
b.create<memref::StoreOp>(loc, val, xferOp.getSource(), indices);
});
b.create<scf::YieldOp>(loc);
@@ -1614,7 +1613,7 @@ struct Strategy1d<TransferWriteOp> {
/// This pattern generates IR as follows:
///
/// 1. Generate a for loop iterating over each vector element.
-/// 2. Inside the loop, generate a InsertElementOp or ExtractElementOp,
+/// 2. Inside the loop, generate a InsertOp or ExtractOp,
/// depending on OpTy.
///
/// TODO: In some cases (no masking, etc.), LLVM::MatrixColumnMajorLoadOp
@@ -1630,7 +1629,7 @@ struct Strategy1d<TransferWriteOp> {
/// Is rewritten to approximately the following pseudo-IR:
/// ```
/// for i = 0 to 9 {
-/// %t = vector.extractelement %vec[i] : vector<9xf32>
+/// %t = vector.extract %vec[i] : vector<9xf32>
/// memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32>
/// }
/// ```
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 0a2457176a1d47..e38bbad1637d45 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1134,8 +1134,6 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
// * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
// (0th) element and use that.
SmallVector<Value> transferReadIdxs;
- auto zero = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI32Type(), rewriter.getZeroAttr(rewriter.getI32Type()));
for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
Value idx = bvm.lookup(extractOp.getIndices()[i]);
if (idx.getType().isIndex()) {
@@ -1149,7 +1147,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
resultType.getScalableDims().back()),
idx);
transferReadIdxs.push_back(
- rewriter.create<vector::ExtractElementOp>(loc, indexAs1dVector, zero));
+ rewriter.create<vector::ExtractOp>(loc, indexAs1dVector, 0));
}
// `tensor.extract_element` is always in-bounds, hence the following holds.
@@ -1415,7 +1413,8 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
// 3.c. Not all ops support 0-d vectors, extract the scalar for now.
// TODO: remove this.
if (readType.getRank() == 0)
- readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
+ readValue = rewriter.create<vector::ExtractOp>(loc, readValue,
+ SmallVector<int64_t>{});
LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue
<< "\n");
@@ -2268,7 +2267,8 @@ LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
loc, readType, copyOp.getSource(), indices,
rewriter.getMultiDimIdentityMap(srcType.getRank()));
if (cast<VectorType>(readValue.getType()).getRank() == 0) {
- readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
+ readValue = rewriter.create<vector::ExtractOp>(loc, readValue,
+ SmallVector<int64_t>{});
readValue = rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
}
Operation *writeValue = rewriter.create<vector::TransferWriteOp>(
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d71a236f62f454..af5b3637bf5b10 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -697,8 +697,9 @@ struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
Value result;
if (vectorType.getRank() == 0) {
if (mask)
- mask = rewriter.create<ExtractElementOp>(loc, mask);
- result = rewriter.create<ExtractElementOp>(loc, reductionOp.getVector());
+ mask = rewriter.create<ExtractOp>(loc, mask, SmallVector<int64_t>{});
+ result = rewriter.create<ExtractOp>(loc, reductionOp.getVector(),
+ SmallVector<int64_t>{});
} else {
if (mask)
mask = rewriter.create<ExtractOp>(loc, mask, 0);
@@ -1983,12 +1984,18 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
if (extractResultRank < broadcastSrcRank)
return failure();
- // Special case if broadcast src is a 0D vector.
+ // If extractResultRank is 0, broadcastSrcRank has to be zero, since
+ // broadcastSrcRank >= extractResultRank for this pattern. If so, the input
+ // to the broadcast will be a vector<f32> or f32, but the result will be a
+ // f32, because of vector.extract 0-d semantics. Therefore, we instead
+ // just replace the broadcast with a vector.extract.
if (extractResultRank == 0) {
assert(broadcastSrcRank == 0 && llvm::isa<VectorType>(source.getType()));
- rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(extractOp, source);
+ rewriter.replaceOpWithNewOp<vector::ExtractOp>(extractOp, source,
+ SmallVector<int64_t>{});
return success();
}
+
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
extractOp, extractOp.getType(), source);
return success();
@@ -2951,11 +2958,11 @@ void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
InsertOpConstantFolder>(context);
}
-// Eliminates insert operations that produce values identical to their source
-// value. This happens when the source and destination vectors have identical
-// sizes.
OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
- if (getNumIndices() == 0)
+ // Fold "vector.insert %v, %dest [] : vector<2x2xf32> from vector<2x2xf32>" to
+ // %v. Note: Do not fold "vector.insert %v, %dest [] : f32 into vector<f32>"
+ // (type mismatch).
+ if (getNumIndices() == 0 && getSourceType() == getResult().getType())
return getSource();
return {};
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
index 6c36bbaee85237..6d82d753eeed80 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
@@ -65,7 +65,8 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
if (srcRank <= 1 && dstRank == 1) {
Value ext;
if (srcRank == 0)
- ext = rewriter.create<vector::ExtractElementOp>(loc, op.getSource());
+ ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(),
+ SmallVector<int64_t>{});
else
ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);
rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 716da55ba09aec..72bf329daaa76e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -391,9 +391,8 @@ struct TwoDimMultiReductionToReduction
reductionOp = mlir::vector::maskOperation(rewriter, reductionOp, mask);
}
- result = rewriter.create<vector::InsertElementOp>(
- loc, reductionOp->getResult(0), result,
- rewriter.create<arith::ConstantIndexOp>(loc, i));
+ result = rewriter.create<vector::InsertOp>(loc, reductionOp->getResult(0),
+ result, i);
}
rewriter.replaceOp(rootOp, result);
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
index 95ebd4e9fe3d99..343178c8156d25 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
@@ -177,24 +177,8 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
}
Value extract;
- if (srcRank == 0) {
- // 0-D vector special case
- assert(srcIdx.empty() && "Unexpected indices for 0-D vector");
- extract = rewriter.create<vector::ExtractElementOp>(
- loc, op.getSourceVectorType().getElementType(), op.getSource());
- } else {
- extract =
- rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
- }
-
- if (resRank == 0) {
- // 0-D vector special case
- assert(resIdx.empty() && "Unexpected indices for 0-D vector");
- result = rewriter.create<vector::InsertElementOp>(loc, extract, result);
- } else {
- result =
- rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
- }
+ extract = rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
+ result = rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
}
rewriter.replaceOp(op, result);
return success();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 2289fd1ff1364e..4ea6bcf3181dae 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1238,7 +1238,7 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
if (extractOp.getNumIndices() == 0)
return failure();
- // Rewrite vector.extract with 1d source to vector.extractelement.
+ // Rewrite vector.extract with 1d source to vector.extract.
if (extractSrcType.getRank() == 1) {
if (extractOp.hasDynamicPosition())
// TODO: Dinamic position not supported yet.
@@ -1247,9 +1247,8 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
assert(extractOp.getNumIndices() == 1 && "expected 1 index");
int64_t pos = extractOp.getStaticPosition()[0];
rewriter.setInsertionPoint(extractOp);
- rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(
- extractOp, extractOp.getVector(),
- rewriter.create<arith::ConstantIndexOp>(loc, pos));
+ rewriter.replaceOpWithNewOp<vector::ExtractOp>(
+ extractOp, extractOp.getVector(), pos);
return success();
}
@@ -1519,9 +1518,8 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
assert(insertOp.getNumIndices() == 1 && "expected 1 index");
int64_t pos = insertOp.getStaticPosition()[0];
rewriter.setInsertionPoint(insertOp);
- rewriter.replaceOpWithNewOp<vector::InsertElementOp>(
- insertOp, insertOp.getSource(), insertOp.getDest(),
- rewriter.create<arith::ConstantIndexOp>(loc, pos));
+ rewriter.replaceOpWithNewOp<vector::InsertOp>(
+ insertOp, insertOp.getSource(), insertOp.getDest(), pos);
return success();
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
index ec2ef3fc7501c2..a5d5dc00b33cd3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -21,23 +21,13 @@ using namespace mlir::vector;
// Helper that picks the proper sequence for inserting.
static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
Value into, int64_t offset) {
- auto vectorType = cast<VectorType>(into.getType());
- if (vectorType.getRank() > 1)
- return rewriter.create<InsertOp>(loc, from, into, offset);
- return rewriter.create<vector::InsertElementOp>(
- loc, vectorType, from, into,
- rewriter.create<arith::ConstantIndexOp>(loc, offset));
+ return rewriter.create<InsertOp>(loc, from, into, offset);
}
// Helper that picks the proper sequence for extracting.
static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
int64_t offset) {
- auto vectorType = cast<VectorType>(vector.getType());
- if (vectorType.getRank() > 1)
- return rewriter.create<ExtractOp>(loc, vector, offset);
- return rewriter.create<vector::ExtractElementOp>(
- loc, vectorType.getElementType(), vector,
- rewriter.create<arith::ConstantIndexOp>(loc, offset));
+ return rewriter.create<ExtractOp>(loc, vector, offset);
}
/// RewritePattern for InsertStridedSliceOp where source and destination vectors
@@ -277,8 +267,8 @@ class Convert1DExtractStridedSliceIntoExtractInsertChain final
};
/// RewritePattern for ExtractStridedSliceOp where the source vector is n-D.
-/// For such cases, we can rewrite it to ExtractOp/ExtractElementOp + lower
-/// rank ExtractStridedSliceOp + InsertOp/InsertElementOp for the n-D case.
+/// For such cases, we can rewrite it to ExtractOp + lower rank
+/// ExtractStridedSliceOp + InsertOp for the n-D case.
class DecomposeNDExtractStridedSlice
: public OpRewritePattern<ExtractStridedSliceOp> {
public:
diff --git a/mlir/test/Conversion/ArithToAMDGPU/16-bit-floats.mlir b/mlir/test/Conversion/ArithToAMDGPU/16-bit-floats.mlir
index 121cae26748a82..8991506dee1dfb 100644
--- a/mlir/test/Conversion/ArithToAMDGPU/16-bit-floats.mlir
+++ b/mlir/test/Conversion/ArithToAMDGPU/16-bit-floats.mlir
@@ -5,7 +5,7 @@
func.func @scalar_trunc(%v: f32) -> f16{
// CHECK: %[[poison:.*]] = llvm.mlir.poison : f32
// CHECK: %[[trunc:.*]] = rocdl.cvt.pkrtz %[[value]], %[[poison]] : vector<2xf16>
- // CHECK: %[[extract:.*]] = vector.extractelement %[[trunc]][%c0 : index] : vector<2xf16>
+ // CHECK: %[[extract:.*]] = vector.extract %[[trunc]][0] : f16 from vector<2xf16>
// CHECK: return %[[extract]] : f16
%w = arith.truncf %v : f32 to f16
return %w : f16
@@ -14,8 +14,8 @@ func.func @scalar_trunc(%v: f32) -> f16{
// CHECK-LABEL: @vector_trunc
// CHECK-SAME: (%[[value:.*]]: vector<2xf32>)
func.func @vector_trunc_short(%v: vector<2xf32>) -> vector<2xf16> {
- // CHECK: %[[elem0:.*]] = vector.extractelement %[[value]]
- // CHECK: %[[elem1:.*]] = vector.extractelement %[[value]]
+ // CHECK: %[[elem0:.*]] = vector.extract %[[value]]
+ // CHECK: %[[elem1:.*]] = vector.extract %[[value]]
// CHECK: %[[ret:.*]] = rocdl.cvt.pkrtz %[[elem0]], %[[elem1]] : vector<2xf16>
// CHECK: return %[[ret]]
%w = arith.truncf %v : vector<2xf32> to vector<2xf16>
@@ -25,23 +25,23 @@ func.func @vector_trunc_short(%v: vector<2xf32>) -> vector<2xf16> {
// CHECK-LABEL: @vector_trunc_long
// CHECK-SAME: (%[[value:.*]]: vector<9xf32>)
func.func @vector_trunc_long(%v: vector<9xf32>) -> vector<9xf16> {
- // CHECK: %[[elem0:.*]] = vector.extractelement %[[value]][%c0 : index]
- // CHECK: %[[elem1:.*]] = vector.extractelement %[[value]][%c1 : index]
+ // CHECK: %[[elem0:.*]] = vector.extract %[[value]][0]
+ // CHECK: %[[elem1:.*]] = vector.extract %[[value]][1]
// CHECK: %[[packed0:.*]] = rocdl.cvt.pkrtz %[[elem0]], %[[elem1]] : vector<2xf16>
// CHECK: %[[out0:.*]] = vector.insert_strided_slice %[[packed0]], {{.*}} {offsets = [0], strides = [1]} : vector<2xf16> into vector<9xf16>
- // CHECK: %[[elem2:.*]] = vector.extractelement %[[value]][%c2 : index]
- // CHECK: %[[elem3:.*]] = vector.extractelement %[[value]][%c3 : index]
+ // CHECK: %[[elem2:.*]] = vector.extract %[[value]][2]
+ // CHECK: %[[elem3:.*]] = vector.extract %[[value]][3]
// CHECK: %[[packed1:.*]] = rocdl.cvt.pkrtz %[[elem2]], %[[elem3]] : vector<2xf16>
// CHECK: %[[out1:.*]] = vector.insert_strided_slice %[[packed1]], %[[out0]] {offsets = [2], strides = [1]} : vector<2xf16> into...
[truncated]
|
@llvm/pr-subscribers-mlir-gpu Author: Kunwar Grover (Groverkss) ChangesThis patch removes usages of vector.extractelement/vector.insertelement. These operations can be fully represented by vector.extract/vector.insert. See https://discourse.llvm.org/t/rfc-psa-remove-vector-extractelement-and-vector-insertelement-ops-in-favor-of-vector-extract-and-vector-insert-ops/71116 for more information. Patch is 73.49 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/113827.diff 20 Files Affected:
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 6b27ec9947cb0b..6b9cbaf57676c2 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -313,8 +313,7 @@ void TruncfToFloat16RewritePattern::rewrite(arith::TruncFOp op,
auto sourceB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
Value asF16s =
rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, in, sourceB);
- Value result = rewriter.create<vector::ExtractElementOp>(
- loc, asF16s, rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0));
+ Value result = rewriter.create<vector::ExtractOp>(loc, asF16s, 0);
return rewriter.replaceOp(op, result);
}
VectorType outType = cast<VectorType>(op.getOut().getType());
@@ -334,13 +333,11 @@ void TruncfToFloat16RewritePattern::rewrite(arith::TruncFOp op,
for (int64_t i = 0; i < numElements; i += 2) {
int64_t elemsThisOp = std::min(numElements, i + 2) - i;
Value thisResult = nullptr;
- Value elemA = rewriter.create<vector::ExtractElementOp>(
- loc, in, rewriter.create<arith::ConstantIndexOp>(loc, i));
+ Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i);
Value elemB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
if (elemsThisOp == 2) {
- elemB = rewriter.create<vector::ExtractElementOp>(
- loc, in, rewriter.createOrFold<arith::ConstantIndexOp>(loc, i + 1));
+ elemB = rewriter.create<vector::ExtractOp>(loc, in, i + 1);
}
thisResult =
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 3a4dc806efe976..ddbc4d2c4a4f3d 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -157,7 +157,7 @@ static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) {
return Value();
Location loc = xferOp.getLoc();
- return b.create<vector::ExtractElementOp>(loc, xferOp.getMask(), iv);
+ return b.create<vector::ExtractOp>(loc, xferOp.getMask(), iv);
}
/// Helper function TransferOpConversion and TransferOp1dConversion.
@@ -686,7 +686,7 @@ struct PrepareTransferWriteConversion
/// %lastIndex = arith.subi %length, %c1 : index
/// vector.print punctuation <open>
/// scf.for %i = %c0 to %length step %c1 {
-/// %el = vector.extractelement %v[%i : index] : vector<[4]xi32>
+/// %el = vector.extract %v[%i : index] : vector<[4]xi32>
/// vector.print %el : i32 punctuation <no_punctuation>
/// %notLastIndex = arith.cmpi ult, %i, %lastIndex : index
/// scf.if %notLastIndex {
@@ -756,7 +756,8 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
if (vectorType.getRank() != 1) {
// Flatten n-D vectors to 1D. This is done to allow indexing with a
// non-constant value (which can currently only be done via
- // vector.extractelement for 1D vectors).
+ // vector.extract for 1D vectors).
+ // TODO: vector.extract supports N-D non-constant indices now.
auto flatLength = std::accumulate(shape.begin(), shape.end(), 1,
std::multiplies<int64_t>());
auto flatVectorType =
@@ -819,8 +820,7 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
}
// Print the scalar elements in the inner most loop.
- auto element =
- rewriter.create<vector::ExtractElementOp>(loc, value, flatIndex);
+ auto element = rewriter.create<vector::ExtractOp>(loc, value, flatIndex);
rewriter.create<vector::PrintOp>(loc, element,
vector::PrintPunctuation::NoPunctuation);
@@ -1563,7 +1563,7 @@ struct Strategy1d<TransferReadOp> {
[&](OpBuilder &b, Location loc) {
Value val =
b.create<memref::LoadOp>(loc, xferOp.getSource(), indices);
- return b.create<vector::InsertElementOp>(loc, val, vec, iv);
+ return b.create<vector::InsertOp>(loc, val, vec, iv);
},
/*outOfBoundsCase=*/
[&](OpBuilder & /*b*/, Location loc) { return vec; });
@@ -1591,8 +1591,7 @@ struct Strategy1d<TransferWriteOp> {
generateInBoundsCheck(
b, xferOp, iv, dim,
/*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
- auto val =
- b.create<vector::ExtractElementOp>(loc, xferOp.getVector(), iv);
+ auto val = b.create<vector::ExtractOp>(loc, xferOp.getVector(), iv);
b.create<memref::StoreOp>(loc, val, xferOp.getSource(), indices);
});
b.create<scf::YieldOp>(loc);
@@ -1614,7 +1613,7 @@ struct Strategy1d<TransferWriteOp> {
/// This pattern generates IR as follows:
///
/// 1. Generate a for loop iterating over each vector element.
-/// 2. Inside the loop, generate a InsertElementOp or ExtractElementOp,
+/// 2. Inside the loop, generate a InsertOp or ExtractOp,
/// depending on OpTy.
///
/// TODO: In some cases (no masking, etc.), LLVM::MatrixColumnMajorLoadOp
@@ -1630,7 +1629,7 @@ struct Strategy1d<TransferWriteOp> {
/// Is rewritten to approximately the following pseudo-IR:
/// ```
/// for i = 0 to 9 {
-/// %t = vector.extractelement %vec[i] : vector<9xf32>
+/// %t = vector.extract %vec[i] : vector<9xf32>
/// memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32>
/// }
/// ```
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 0a2457176a1d47..e38bbad1637d45 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1134,8 +1134,6 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
// * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
// (0th) element and use that.
SmallVector<Value> transferReadIdxs;
- auto zero = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI32Type(), rewriter.getZeroAttr(rewriter.getI32Type()));
for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
Value idx = bvm.lookup(extractOp.getIndices()[i]);
if (idx.getType().isIndex()) {
@@ -1149,7 +1147,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
resultType.getScalableDims().back()),
idx);
transferReadIdxs.push_back(
- rewriter.create<vector::ExtractElementOp>(loc, indexAs1dVector, zero));
+ rewriter.create<vector::ExtractOp>(loc, indexAs1dVector, 0));
}
// `tensor.extract_element` is always in-bounds, hence the following holds.
@@ -1415,7 +1413,8 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
// 3.c. Not all ops support 0-d vectors, extract the scalar for now.
// TODO: remove this.
if (readType.getRank() == 0)
- readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
+ readValue = rewriter.create<vector::ExtractOp>(loc, readValue,
+ SmallVector<int64_t>{});
LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue
<< "\n");
@@ -2268,7 +2267,8 @@ LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
loc, readType, copyOp.getSource(), indices,
rewriter.getMultiDimIdentityMap(srcType.getRank()));
if (cast<VectorType>(readValue.getType()).getRank() == 0) {
- readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
+ readValue = rewriter.create<vector::ExtractOp>(loc, readValue,
+ SmallVector<int64_t>{});
readValue = rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
}
Operation *writeValue = rewriter.create<vector::TransferWriteOp>(
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d71a236f62f454..af5b3637bf5b10 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -697,8 +697,9 @@ struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
Value result;
if (vectorType.getRank() == 0) {
if (mask)
- mask = rewriter.create<ExtractElementOp>(loc, mask);
- result = rewriter.create<ExtractElementOp>(loc, reductionOp.getVector());
+ mask = rewriter.create<ExtractOp>(loc, mask, SmallVector<int64_t>{});
+ result = rewriter.create<ExtractOp>(loc, reductionOp.getVector(),
+ SmallVector<int64_t>{});
} else {
if (mask)
mask = rewriter.create<ExtractOp>(loc, mask, 0);
@@ -1983,12 +1984,18 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
if (extractResultRank < broadcastSrcRank)
return failure();
- // Special case if broadcast src is a 0D vector.
+ // If extractResultRank is 0, broadcastSrcRank has to be zero, since
+ // broadcastSrcRank >= extractResultRank for this pattern. If so, the input
+ // to the broadcast will be a vector<f32> or f32, but the result will be a
+ // f32, because of vector.extract 0-d semantics. Therefore, we instead
+ // just replace the broadcast with a vector.extract.
if (extractResultRank == 0) {
assert(broadcastSrcRank == 0 && llvm::isa<VectorType>(source.getType()));
- rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(extractOp, source);
+ rewriter.replaceOpWithNewOp<vector::ExtractOp>(extractOp, source,
+ SmallVector<int64_t>{});
return success();
}
+
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
extractOp, extractOp.getType(), source);
return success();
@@ -2951,11 +2958,11 @@ void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
InsertOpConstantFolder>(context);
}
-// Eliminates insert operations that produce values identical to their source
-// value. This happens when the source and destination vectors have identical
-// sizes.
OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
- if (getNumIndices() == 0)
+ // Fold "vector.insert %v, %dest [] : vector<2x2xf32> from vector<2x2xf32>" to
+ // %v. Note: Do not fold "vector.insert %v, %dest [] : f32 into vector<f32>"
+ // (type mismatch).
+ if (getNumIndices() == 0 && getSourceType() == getResult().getType())
return getSource();
return {};
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
index 6c36bbaee85237..6d82d753eeed80 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
@@ -65,7 +65,8 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
if (srcRank <= 1 && dstRank == 1) {
Value ext;
if (srcRank == 0)
- ext = rewriter.create<vector::ExtractElementOp>(loc, op.getSource());
+ ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(),
+ SmallVector<int64_t>{});
else
ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);
rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 716da55ba09aec..72bf329daaa76e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -391,9 +391,8 @@ struct TwoDimMultiReductionToReduction
reductionOp = mlir::vector::maskOperation(rewriter, reductionOp, mask);
}
- result = rewriter.create<vector::InsertElementOp>(
- loc, reductionOp->getResult(0), result,
- rewriter.create<arith::ConstantIndexOp>(loc, i));
+ result = rewriter.create<vector::InsertOp>(loc, reductionOp->getResult(0),
+ result, i);
}
rewriter.replaceOp(rootOp, result);
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
index 95ebd4e9fe3d99..343178c8156d25 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
@@ -177,24 +177,8 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
}
Value extract;
- if (srcRank == 0) {
- // 0-D vector special case
- assert(srcIdx.empty() && "Unexpected indices for 0-D vector");
- extract = rewriter.create<vector::ExtractElementOp>(
- loc, op.getSourceVectorType().getElementType(), op.getSource());
- } else {
- extract =
- rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
- }
-
- if (resRank == 0) {
- // 0-D vector special case
- assert(resIdx.empty() && "Unexpected indices for 0-D vector");
- result = rewriter.create<vector::InsertElementOp>(loc, extract, result);
- } else {
- result =
- rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
- }
+ extract = rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
+ result = rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
}
rewriter.replaceOp(op, result);
return success();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 2289fd1ff1364e..4ea6bcf3181dae 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1238,7 +1238,7 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
if (extractOp.getNumIndices() == 0)
return failure();
- // Rewrite vector.extract with 1d source to vector.extractelement.
+ // Rewrite vector.extract with 1d source to vector.extract.
if (extractSrcType.getRank() == 1) {
if (extractOp.hasDynamicPosition())
// TODO: Dinamic position not supported yet.
@@ -1247,9 +1247,8 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
assert(extractOp.getNumIndices() == 1 && "expected 1 index");
int64_t pos = extractOp.getStaticPosition()[0];
rewriter.setInsertionPoint(extractOp);
- rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(
- extractOp, extractOp.getVector(),
- rewriter.create<arith::ConstantIndexOp>(loc, pos));
+ rewriter.replaceOpWithNewOp<vector::ExtractOp>(
+ extractOp, extractOp.getVector(), pos);
return success();
}
@@ -1519,9 +1518,8 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
assert(insertOp.getNumIndices() == 1 && "expected 1 index");
int64_t pos = insertOp.getStaticPosition()[0];
rewriter.setInsertionPoint(insertOp);
- rewriter.replaceOpWithNewOp<vector::InsertElementOp>(
- insertOp, insertOp.getSource(), insertOp.getDest(),
- rewriter.create<arith::ConstantIndexOp>(loc, pos));
+ rewriter.replaceOpWithNewOp<vector::InsertOp>(
+ insertOp, insertOp.getSource(), insertOp.getDest(), pos);
return success();
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
index ec2ef3fc7501c2..a5d5dc00b33cd3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -21,23 +21,13 @@ using namespace mlir::vector;
// Helper that picks the proper sequence for inserting.
static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
Value into, int64_t offset) {
- auto vectorType = cast<VectorType>(into.getType());
- if (vectorType.getRank() > 1)
- return rewriter.create<InsertOp>(loc, from, into, offset);
- return rewriter.create<vector::InsertElementOp>(
- loc, vectorType, from, into,
- rewriter.create<arith::ConstantIndexOp>(loc, offset));
+ return rewriter.create<InsertOp>(loc, from, into, offset);
}
// Helper that picks the proper sequence for extracting.
static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
int64_t offset) {
- auto vectorType = cast<VectorType>(vector.getType());
- if (vectorType.getRank() > 1)
- return rewriter.create<ExtractOp>(loc, vector, offset);
- return rewriter.create<vector::ExtractElementOp>(
- loc, vectorType.getElementType(), vector,
- rewriter.create<arith::ConstantIndexOp>(loc, offset));
+ return rewriter.create<ExtractOp>(loc, vector, offset);
}
/// RewritePattern for InsertStridedSliceOp where source and destination vectors
@@ -277,8 +267,8 @@ class Convert1DExtractStridedSliceIntoExtractInsertChain final
};
/// RewritePattern for ExtractStridedSliceOp where the source vector is n-D.
-/// For such cases, we can rewrite it to ExtractOp/ExtractElementOp + lower
-/// rank ExtractStridedSliceOp + InsertOp/InsertElementOp for the n-D case.
+/// For such cases, we can rewrite it to ExtractOp + lower rank
+/// ExtractStridedSliceOp + InsertOp for the n-D case.
class DecomposeNDExtractStridedSlice
: public OpRewritePattern<ExtractStridedSliceOp> {
public:
diff --git a/mlir/test/Conversion/ArithToAMDGPU/16-bit-floats.mlir b/mlir/test/Conversion/ArithToAMDGPU/16-bit-floats.mlir
index 121cae26748a82..8991506dee1dfb 100644
--- a/mlir/test/Conversion/ArithToAMDGPU/16-bit-floats.mlir
+++ b/mlir/test/Conversion/ArithToAMDGPU/16-bit-floats.mlir
@@ -5,7 +5,7 @@
func.func @scalar_trunc(%v: f32) -> f16{
// CHECK: %[[poison:.*]] = llvm.mlir.poison : f32
// CHECK: %[[trunc:.*]] = rocdl.cvt.pkrtz %[[value]], %[[poison]] : vector<2xf16>
- // CHECK: %[[extract:.*]] = vector.extractelement %[[trunc]][%c0 : index] : vector<2xf16>
+ // CHECK: %[[extract:.*]] = vector.extract %[[trunc]][0] : f16 from vector<2xf16>
// CHECK: return %[[extract]] : f16
%w = arith.truncf %v : f32 to f16
return %w : f16
@@ -14,8 +14,8 @@ func.func @scalar_trunc(%v: f32) -> f16{
// CHECK-LABEL: @vector_trunc
// CHECK-SAME: (%[[value:.*]]: vector<2xf32>)
func.func @vector_trunc_short(%v: vector<2xf32>) -> vector<2xf16> {
- // CHECK: %[[elem0:.*]] = vector.extractelement %[[value]]
- // CHECK: %[[elem1:.*]] = vector.extractelement %[[value]]
+ // CHECK: %[[elem0:.*]] = vector.extract %[[value]]
+ // CHECK: %[[elem1:.*]] = vector.extract %[[value]]
// CHECK: %[[ret:.*]] = rocdl.cvt.pkrtz %[[elem0]], %[[elem1]] : vector<2xf16>
// CHECK: return %[[ret]]
%w = arith.truncf %v : vector<2xf32> to vector<2xf16>
@@ -25,23 +25,23 @@ func.func @vector_trunc_short(%v: vector<2xf32>) -> vector<2xf16> {
// CHECK-LABEL: @vector_trunc_long
// CHECK-SAME: (%[[value:.*]]: vector<9xf32>)
func.func @vector_trunc_long(%v: vector<9xf32>) -> vector<9xf16> {
- // CHECK: %[[elem0:.*]] = vector.extractelement %[[value]][%c0 : index]
- // CHECK: %[[elem1:.*]] = vector.extractelement %[[value]][%c1 : index]
+ // CHECK: %[[elem0:.*]] = vector.extract %[[value]][0]
+ // CHECK: %[[elem1:.*]] = vector.extract %[[value]][1]
// CHECK: %[[packed0:.*]] = rocdl.cvt.pkrtz %[[elem0]], %[[elem1]] : vector<2xf16>
// CHECK: %[[out0:.*]] = vector.insert_strided_slice %[[packed0]], {{.*}} {offsets = [0], strides = [1]} : vector<2xf16> into vector<9xf16>
- // CHECK: %[[elem2:.*]] = vector.extractelement %[[value]][%c2 : index]
- // CHECK: %[[elem3:.*]] = vector.extractelement %[[value]][%c3 : index]
+ // CHECK: %[[elem2:.*]] = vector.extract %[[value]][2]
+ // CHECK: %[[elem3:.*]] = vector.extract %[[value]][3]
// CHECK: %[[packed1:.*]] = rocdl.cvt.pkrtz %[[elem2]], %[[elem3]] : vector<2xf16>
// CHECK: %[[out1:.*]] = vector.insert_strided_slice %[[packed1]], %[[out0]] {offsets = [2], strides = [1]} : vector<2xf16> into...
[truncated]
|
@llvm/pr-subscribers-mlir-linalg Author: Kunwar Grover (Groverkss) ChangesThis patch removes usages of vector.extractelement/vector.insertelement. These operations can be fully represented by vector.extract/vector.insert. See https://discourse.llvm.org/t/rfc-psa-remove-vector-extractelement-and-vector-insertelement-ops-in-favor-of-vector-extract-and-vector-insert-ops/71116 for more information. Patch is 73.49 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/113827.diff 20 Files Affected:
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 6b27ec9947cb0b..6b9cbaf57676c2 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -313,8 +313,7 @@ void TruncfToFloat16RewritePattern::rewrite(arith::TruncFOp op,
auto sourceB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
Value asF16s =
rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, in, sourceB);
- Value result = rewriter.create<vector::ExtractElementOp>(
- loc, asF16s, rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0));
+ Value result = rewriter.create<vector::ExtractOp>(loc, asF16s, 0);
return rewriter.replaceOp(op, result);
}
VectorType outType = cast<VectorType>(op.getOut().getType());
@@ -334,13 +333,11 @@ void TruncfToFloat16RewritePattern::rewrite(arith::TruncFOp op,
for (int64_t i = 0; i < numElements; i += 2) {
int64_t elemsThisOp = std::min(numElements, i + 2) - i;
Value thisResult = nullptr;
- Value elemA = rewriter.create<vector::ExtractElementOp>(
- loc, in, rewriter.create<arith::ConstantIndexOp>(loc, i));
+ Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i);
Value elemB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
if (elemsThisOp == 2) {
- elemB = rewriter.create<vector::ExtractElementOp>(
- loc, in, rewriter.createOrFold<arith::ConstantIndexOp>(loc, i + 1));
+ elemB = rewriter.create<vector::ExtractOp>(loc, in, i + 1);
}
thisResult =
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 3a4dc806efe976..ddbc4d2c4a4f3d 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -157,7 +157,7 @@ static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) {
return Value();
Location loc = xferOp.getLoc();
- return b.create<vector::ExtractElementOp>(loc, xferOp.getMask(), iv);
+ return b.create<vector::ExtractOp>(loc, xferOp.getMask(), iv);
}
/// Helper function TransferOpConversion and TransferOp1dConversion.
@@ -686,7 +686,7 @@ struct PrepareTransferWriteConversion
/// %lastIndex = arith.subi %length, %c1 : index
/// vector.print punctuation <open>
/// scf.for %i = %c0 to %length step %c1 {
-/// %el = vector.extractelement %v[%i : index] : vector<[4]xi32>
+/// %el = vector.extract %v[%i : index] : vector<[4]xi32>
/// vector.print %el : i32 punctuation <no_punctuation>
/// %notLastIndex = arith.cmpi ult, %i, %lastIndex : index
/// scf.if %notLastIndex {
@@ -756,7 +756,8 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
if (vectorType.getRank() != 1) {
// Flatten n-D vectors to 1D. This is done to allow indexing with a
// non-constant value (which can currently only be done via
- // vector.extractelement for 1D vectors).
+ // vector.extract for 1D vectors).
+ // TODO: vector.extract supports N-D non-constant indices now.
auto flatLength = std::accumulate(shape.begin(), shape.end(), 1,
std::multiplies<int64_t>());
auto flatVectorType =
@@ -819,8 +820,7 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
}
// Print the scalar elements in the inner most loop.
- auto element =
- rewriter.create<vector::ExtractElementOp>(loc, value, flatIndex);
+ auto element = rewriter.create<vector::ExtractOp>(loc, value, flatIndex);
rewriter.create<vector::PrintOp>(loc, element,
vector::PrintPunctuation::NoPunctuation);
@@ -1563,7 +1563,7 @@ struct Strategy1d<TransferReadOp> {
[&](OpBuilder &b, Location loc) {
Value val =
b.create<memref::LoadOp>(loc, xferOp.getSource(), indices);
- return b.create<vector::InsertElementOp>(loc, val, vec, iv);
+ return b.create<vector::InsertOp>(loc, val, vec, iv);
},
/*outOfBoundsCase=*/
[&](OpBuilder & /*b*/, Location loc) { return vec; });
@@ -1591,8 +1591,7 @@ struct Strategy1d<TransferWriteOp> {
generateInBoundsCheck(
b, xferOp, iv, dim,
/*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
- auto val =
- b.create<vector::ExtractElementOp>(loc, xferOp.getVector(), iv);
+ auto val = b.create<vector::ExtractOp>(loc, xferOp.getVector(), iv);
b.create<memref::StoreOp>(loc, val, xferOp.getSource(), indices);
});
b.create<scf::YieldOp>(loc);
@@ -1614,7 +1613,7 @@ struct Strategy1d<TransferWriteOp> {
/// This pattern generates IR as follows:
///
/// 1. Generate a for loop iterating over each vector element.
-/// 2. Inside the loop, generate a InsertElementOp or ExtractElementOp,
+/// 2. Inside the loop, generate a InsertOp or ExtractOp,
/// depending on OpTy.
///
/// TODO: In some cases (no masking, etc.), LLVM::MatrixColumnMajorLoadOp
@@ -1630,7 +1629,7 @@ struct Strategy1d<TransferWriteOp> {
/// Is rewritten to approximately the following pseudo-IR:
/// ```
/// for i = 0 to 9 {
-/// %t = vector.extractelement %vec[i] : vector<9xf32>
+/// %t = vector.extract %vec[i] : vector<9xf32>
/// memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32>
/// }
/// ```
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 0a2457176a1d47..e38bbad1637d45 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1134,8 +1134,6 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
// * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
// (0th) element and use that.
SmallVector<Value> transferReadIdxs;
- auto zero = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI32Type(), rewriter.getZeroAttr(rewriter.getI32Type()));
for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
Value idx = bvm.lookup(extractOp.getIndices()[i]);
if (idx.getType().isIndex()) {
@@ -1149,7 +1147,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
resultType.getScalableDims().back()),
idx);
transferReadIdxs.push_back(
- rewriter.create<vector::ExtractElementOp>(loc, indexAs1dVector, zero));
+ rewriter.create<vector::ExtractOp>(loc, indexAs1dVector, 0));
}
// `tensor.extract_element` is always in-bounds, hence the following holds.
@@ -1415,7 +1413,8 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
// 3.c. Not all ops support 0-d vectors, extract the scalar for now.
// TODO: remove this.
if (readType.getRank() == 0)
- readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
+ readValue = rewriter.create<vector::ExtractOp>(loc, readValue,
+ SmallVector<int64_t>{});
LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue
<< "\n");
@@ -2268,7 +2267,8 @@ LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
loc, readType, copyOp.getSource(), indices,
rewriter.getMultiDimIdentityMap(srcType.getRank()));
if (cast<VectorType>(readValue.getType()).getRank() == 0) {
- readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
+ readValue = rewriter.create<vector::ExtractOp>(loc, readValue,
+ SmallVector<int64_t>{});
readValue = rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
}
Operation *writeValue = rewriter.create<vector::TransferWriteOp>(
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d71a236f62f454..af5b3637bf5b10 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -697,8 +697,9 @@ struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
Value result;
if (vectorType.getRank() == 0) {
if (mask)
- mask = rewriter.create<ExtractElementOp>(loc, mask);
- result = rewriter.create<ExtractElementOp>(loc, reductionOp.getVector());
+ mask = rewriter.create<ExtractOp>(loc, mask, SmallVector<int64_t>{});
+ result = rewriter.create<ExtractOp>(loc, reductionOp.getVector(),
+ SmallVector<int64_t>{});
} else {
if (mask)
mask = rewriter.create<ExtractOp>(loc, mask, 0);
@@ -1983,12 +1984,18 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
if (extractResultRank < broadcastSrcRank)
return failure();
- // Special case if broadcast src is a 0D vector.
+ // If extractResultRank is 0, broadcastSrcRank has to be zero, since
+ // broadcastSrcRank >= extractResultRank for this pattern. If so, the input
+ // to the broadcast will be a vector<f32> or f32, but the result will be a
+ // f32, because of vector.extract 0-d semantics. Therefore, we instead
+ // just replace the broadcast with a vector.extract.
if (extractResultRank == 0) {
assert(broadcastSrcRank == 0 && llvm::isa<VectorType>(source.getType()));
- rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(extractOp, source);
+ rewriter.replaceOpWithNewOp<vector::ExtractOp>(extractOp, source,
+ SmallVector<int64_t>{});
return success();
}
+
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
extractOp, extractOp.getType(), source);
return success();
@@ -2951,11 +2958,11 @@ void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
InsertOpConstantFolder>(context);
}
-// Eliminates insert operations that produce values identical to their source
-// value. This happens when the source and destination vectors have identical
-// sizes.
OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
- if (getNumIndices() == 0)
+ // Fold "vector.insert %v, %dest [] : vector<2x2xf32> from vector<2x2xf32>" to
+ // %v. Note: Do not fold "vector.insert %v, %dest [] : f32 into vector<f32>"
+ // (type mismatch).
+ if (getNumIndices() == 0 && getSourceType() == getResult().getType())
return getSource();
return {};
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
index 6c36bbaee85237..6d82d753eeed80 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
@@ -65,7 +65,8 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
if (srcRank <= 1 && dstRank == 1) {
Value ext;
if (srcRank == 0)
- ext = rewriter.create<vector::ExtractElementOp>(loc, op.getSource());
+ ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(),
+ SmallVector<int64_t>{});
else
ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);
rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 716da55ba09aec..72bf329daaa76e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -391,9 +391,8 @@ struct TwoDimMultiReductionToReduction
reductionOp = mlir::vector::maskOperation(rewriter, reductionOp, mask);
}
- result = rewriter.create<vector::InsertElementOp>(
- loc, reductionOp->getResult(0), result,
- rewriter.create<arith::ConstantIndexOp>(loc, i));
+ result = rewriter.create<vector::InsertOp>(loc, reductionOp->getResult(0),
+ result, i);
}
rewriter.replaceOp(rootOp, result);
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
index 95ebd4e9fe3d99..343178c8156d25 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
@@ -177,24 +177,8 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
}
Value extract;
- if (srcRank == 0) {
- // 0-D vector special case
- assert(srcIdx.empty() && "Unexpected indices for 0-D vector");
- extract = rewriter.create<vector::ExtractElementOp>(
- loc, op.getSourceVectorType().getElementType(), op.getSource());
- } else {
- extract =
- rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
- }
-
- if (resRank == 0) {
- // 0-D vector special case
- assert(resIdx.empty() && "Unexpected indices for 0-D vector");
- result = rewriter.create<vector::InsertElementOp>(loc, extract, result);
- } else {
- result =
- rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
- }
+ extract = rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
+ result = rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
}
rewriter.replaceOp(op, result);
return success();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 2289fd1ff1364e..4ea6bcf3181dae 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1238,7 +1238,7 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
if (extractOp.getNumIndices() == 0)
return failure();
- // Rewrite vector.extract with 1d source to vector.extractelement.
+ // Rewrite vector.extract with 1d source to vector.extract.
if (extractSrcType.getRank() == 1) {
if (extractOp.hasDynamicPosition())
// TODO: Dinamic position not supported yet.
@@ -1247,9 +1247,8 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
assert(extractOp.getNumIndices() == 1 && "expected 1 index");
int64_t pos = extractOp.getStaticPosition()[0];
rewriter.setInsertionPoint(extractOp);
- rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(
- extractOp, extractOp.getVector(),
- rewriter.create<arith::ConstantIndexOp>(loc, pos));
+ rewriter.replaceOpWithNewOp<vector::ExtractOp>(
+ extractOp, extractOp.getVector(), pos);
return success();
}
@@ -1519,9 +1518,8 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
assert(insertOp.getNumIndices() == 1 && "expected 1 index");
int64_t pos = insertOp.getStaticPosition()[0];
rewriter.setInsertionPoint(insertOp);
- rewriter.replaceOpWithNewOp<vector::InsertElementOp>(
- insertOp, insertOp.getSource(), insertOp.getDest(),
- rewriter.create<arith::ConstantIndexOp>(loc, pos));
+ rewriter.replaceOpWithNewOp<vector::InsertOp>(
+ insertOp, insertOp.getSource(), insertOp.getDest(), pos);
return success();
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
index ec2ef3fc7501c2..a5d5dc00b33cd3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -21,23 +21,13 @@ using namespace mlir::vector;
// Helper that picks the proper sequence for inserting.
static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
Value into, int64_t offset) {
- auto vectorType = cast<VectorType>(into.getType());
- if (vectorType.getRank() > 1)
- return rewriter.create<InsertOp>(loc, from, into, offset);
- return rewriter.create<vector::InsertElementOp>(
- loc, vectorType, from, into,
- rewriter.create<arith::ConstantIndexOp>(loc, offset));
+ return rewriter.create<InsertOp>(loc, from, into, offset);
}
// Helper that picks the proper sequence for extracting.
static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
int64_t offset) {
- auto vectorType = cast<VectorType>(vector.getType());
- if (vectorType.getRank() > 1)
- return rewriter.create<ExtractOp>(loc, vector, offset);
- return rewriter.create<vector::ExtractElementOp>(
- loc, vectorType.getElementType(), vector,
- rewriter.create<arith::ConstantIndexOp>(loc, offset));
+ return rewriter.create<ExtractOp>(loc, vector, offset);
}
/// RewritePattern for InsertStridedSliceOp where source and destination vectors
@@ -277,8 +267,8 @@ class Convert1DExtractStridedSliceIntoExtractInsertChain final
};
/// RewritePattern for ExtractStridedSliceOp where the source vector is n-D.
-/// For such cases, we can rewrite it to ExtractOp/ExtractElementOp + lower
-/// rank ExtractStridedSliceOp + InsertOp/InsertElementOp for the n-D case.
+/// For such cases, we can rewrite it to ExtractOp + lower rank
+/// ExtractStridedSliceOp + InsertOp for the n-D case.
class DecomposeNDExtractStridedSlice
: public OpRewritePattern<ExtractStridedSliceOp> {
public:
diff --git a/mlir/test/Conversion/ArithToAMDGPU/16-bit-floats.mlir b/mlir/test/Conversion/ArithToAMDGPU/16-bit-floats.mlir
index 121cae26748a82..8991506dee1dfb 100644
--- a/mlir/test/Conversion/ArithToAMDGPU/16-bit-floats.mlir
+++ b/mlir/test/Conversion/ArithToAMDGPU/16-bit-floats.mlir
@@ -5,7 +5,7 @@
func.func @scalar_trunc(%v: f32) -> f16{
// CHECK: %[[poison:.*]] = llvm.mlir.poison : f32
// CHECK: %[[trunc:.*]] = rocdl.cvt.pkrtz %[[value]], %[[poison]] : vector<2xf16>
- // CHECK: %[[extract:.*]] = vector.extractelement %[[trunc]][%c0 : index] : vector<2xf16>
+ // CHECK: %[[extract:.*]] = vector.extract %[[trunc]][0] : f16 from vector<2xf16>
// CHECK: return %[[extract]] : f16
%w = arith.truncf %v : f32 to f16
return %w : f16
@@ -14,8 +14,8 @@ func.func @scalar_trunc(%v: f32) -> f16{
// CHECK-LABEL: @vector_trunc
// CHECK-SAME: (%[[value:.*]]: vector<2xf32>)
func.func @vector_trunc_short(%v: vector<2xf32>) -> vector<2xf16> {
- // CHECK: %[[elem0:.*]] = vector.extractelement %[[value]]
- // CHECK: %[[elem1:.*]] = vector.extractelement %[[value]]
+ // CHECK: %[[elem0:.*]] = vector.extract %[[value]]
+ // CHECK: %[[elem1:.*]] = vector.extract %[[value]]
// CHECK: %[[ret:.*]] = rocdl.cvt.pkrtz %[[elem0]], %[[elem1]] : vector<2xf16>
// CHECK: return %[[ret]]
%w = arith.truncf %v : vector<2xf32> to vector<2xf16>
@@ -25,23 +25,23 @@ func.func @vector_trunc_short(%v: vector<2xf32>) -> vector<2xf16> {
// CHECK-LABEL: @vector_trunc_long
// CHECK-SAME: (%[[value:.*]]: vector<9xf32>)
func.func @vector_trunc_long(%v: vector<9xf32>) -> vector<9xf16> {
- // CHECK: %[[elem0:.*]] = vector.extractelement %[[value]][%c0 : index]
- // CHECK: %[[elem1:.*]] = vector.extractelement %[[value]][%c1 : index]
+ // CHECK: %[[elem0:.*]] = vector.extract %[[value]][0]
+ // CHECK: %[[elem1:.*]] = vector.extract %[[value]][1]
// CHECK: %[[packed0:.*]] = rocdl.cvt.pkrtz %[[elem0]], %[[elem1]] : vector<2xf16>
// CHECK: %[[out0:.*]] = vector.insert_strided_slice %[[packed0]], {{.*}} {offsets = [0], strides = [1]} : vector<2xf16> into vector<9xf16>
- // CHECK: %[[elem2:.*]] = vector.extractelement %[[value]][%c2 : index]
- // CHECK: %[[elem3:.*]] = vector.extractelement %[[value]][%c3 : index]
+ // CHECK: %[[elem2:.*]] = vector.extract %[[value]][2]
+ // CHECK: %[[elem3:.*]] = vector.extract %[[value]][3]
// CHECK: %[[packed1:.*]] = rocdl.cvt.pkrtz %[[elem2]], %[[elem3]] : vector<2xf16>
// CHECK: %[[out1:.*]] = vector.insert_strided_slice %[[packed1]], %[[out0]] {offsets = [2], strides = [1]} : vector<2xf16> into...
[truncated]
|
Depends on #113828 Please only review the top commit until that patch lands |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you so much for the clean-up!
There's a lot of subtle changes here. Would you mind splitting into smaller PRs? For example:
- vectorization
- VectorOps
- etc
?
@@ -756,7 +756,8 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> { | |||
if (vectorType.getRank() != 1) { | |||
// Flatten n-D vectors to 1D. This is done to allow indexing with a | |||
// non-constant value (which can currently only be done via | |||
// vector.extractelement for 1D vectors). | |||
// vector.extract for 1D vectors). | |||
// TODO: vector.extract supports N-D non-constant indices now. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's either remove this block (preferred) or rephrase the outdated comment - right now it's not clear what the TODO is (missing verb).
Suggestion (but I'd really prefer this being removed instead):
// Flatten n-D vectors to 1D. This is done for historical reasons and should be removed (using dynamic indices to extract n-D vectors was not supported when this was added).
// TODO: Remove this block
Or something similar :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fully dynamic n-D vector.inserts/extracts still do not have a lowering.
@@ -233,10 +233,9 @@ func.func @broadcast_vec2d_from_vec0d(%arg0: vector<f32>) -> vector<3x2xf32> { | |||
// CHECK-LABEL: @broadcast_vec2d_from_vec0d( | |||
// CHECK-SAME: %[[A:.*]]: vector<f32>) | |||
// CHECK: %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<f32> to vector<1xf32> | |||
// CHECK: %[[T5:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<1xf32> to f32 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hm, why wouldn't this be lowered to llvm.extractelement
? And is T5
used at all? In fact, would you mind sharing the output?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you so much! LGTM % comments from other reviewers
@@ -1614,7 +1613,7 @@ struct Strategy1d<TransferWriteOp> { | |||
/// This pattern generates IR as follows: | |||
/// | |||
/// 1. Generate a for loop iterating over each vector element. | |||
/// 2. Inside the loop, generate a InsertElementOp or ExtractElementOp, | |||
/// 2. Inside the loop, generate a InsertOp or ExtractOp, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: an
I'm not sure if i can split the entire patch into a pr for each file, but i can help splitting it up into easier review pieces. I sent out the first split for trivial changes: #116053 |
Splitting this patch into smaller pieces |
This patch removes usages of vector.extractelement/vector.insertelement. These operations can be fully represented by vector.extract/vector.insert. See https://discourse.llvm.org/t/rfc-psa-remove-vector-extractelement-and-vector-insertelement-ops-in-favor-of-vector-extract-and-vector-insert-ops/71116 for more information.