Skip to content

[mlir][Vector] Remove uses of vector.extractelement/vector.insertelement #113827

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -313,8 +313,7 @@ void TruncfToFloat16RewritePattern::rewrite(arith::TruncFOp op,
auto sourceB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
Value asF16s =
rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, in, sourceB);
Value result = rewriter.create<vector::ExtractElementOp>(
loc, asF16s, rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0));
Value result = rewriter.create<vector::ExtractOp>(loc, asF16s, 0);
return rewriter.replaceOp(op, result);
}
VectorType outType = cast<VectorType>(op.getOut().getType());
Expand All @@ -334,13 +333,11 @@ void TruncfToFloat16RewritePattern::rewrite(arith::TruncFOp op,
for (int64_t i = 0; i < numElements; i += 2) {
int64_t elemsThisOp = std::min(numElements, i + 2) - i;
Value thisResult = nullptr;
Value elemA = rewriter.create<vector::ExtractElementOp>(
loc, in, rewriter.create<arith::ConstantIndexOp>(loc, i));
Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i);
Value elemB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());

if (elemsThisOp == 2) {
elemB = rewriter.create<vector::ExtractElementOp>(
loc, in, rewriter.createOrFold<arith::ConstantIndexOp>(loc, i + 1));
elemB = rewriter.create<vector::ExtractOp>(loc, in, i + 1);
}

thisResult =
Expand Down
19 changes: 9 additions & 10 deletions mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) {
return Value();

Location loc = xferOp.getLoc();
return b.create<vector::ExtractElementOp>(loc, xferOp.getMask(), iv);
return b.create<vector::ExtractOp>(loc, xferOp.getMask(), iv);
}

/// Helper function TransferOpConversion and TransferOp1dConversion.
Expand Down Expand Up @@ -686,7 +686,7 @@ struct PrepareTransferWriteConversion
/// %lastIndex = arith.subi %length, %c1 : index
/// vector.print punctuation <open>
/// scf.for %i = %c0 to %length step %c1 {
/// %el = vector.extractelement %v[%i : index] : vector<[4]xi32>
/// %el = vector.extract %v[%i : index] : vector<[4]xi32>
/// vector.print %el : i32 punctuation <no_punctuation>
/// %notLastIndex = arith.cmpi ult, %i, %lastIndex : index
/// scf.if %notLastIndex {
Expand Down Expand Up @@ -756,7 +756,8 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
if (vectorType.getRank() != 1) {
// Flatten n-D vectors to 1D. This is done to allow indexing with a
// non-constant value (which can currently only be done via
// vector.extractelement for 1D vectors).
// vector.extract for 1D vectors).
// TODO: vector.extract now accepts n-D non-constant indices; remove this
// flattening once fully dynamic n-D extracts have a lowering.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's either remove this block (preferred) or rephrase the outdated comment - right now it's not clear what the TODO is (missing verb).

Suggestion (but I'd really prefer this being removed instead):

      // Flatten n-D vectors to 1D. This is done for historical reasons and should be removed (using dynamic indices to extract n-D vectors was not supported when this was added).
      // TODO: Remove this block

Or something similar :)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fully dynamic n-D vector.inserts/extracts still do not have a lowering.

auto flatLength = std::accumulate(shape.begin(), shape.end(), 1,
std::multiplies<int64_t>());
auto flatVectorType =
Expand Down Expand Up @@ -819,8 +820,7 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
}

// Print the scalar elements in the inner most loop.
auto element =
rewriter.create<vector::ExtractElementOp>(loc, value, flatIndex);
auto element = rewriter.create<vector::ExtractOp>(loc, value, flatIndex);
rewriter.create<vector::PrintOp>(loc, element,
vector::PrintPunctuation::NoPunctuation);

Expand Down Expand Up @@ -1563,7 +1563,7 @@ struct Strategy1d<TransferReadOp> {
[&](OpBuilder &b, Location loc) {
Value val =
b.create<memref::LoadOp>(loc, xferOp.getSource(), indices);
return b.create<vector::InsertElementOp>(loc, val, vec, iv);
return b.create<vector::InsertOp>(loc, val, vec, iv);
},
/*outOfBoundsCase=*/
[&](OpBuilder & /*b*/, Location loc) { return vec; });
Expand Down Expand Up @@ -1591,8 +1591,7 @@ struct Strategy1d<TransferWriteOp> {
generateInBoundsCheck(
b, xferOp, iv, dim,
/*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
auto val =
b.create<vector::ExtractElementOp>(loc, xferOp.getVector(), iv);
auto val = b.create<vector::ExtractOp>(loc, xferOp.getVector(), iv);
b.create<memref::StoreOp>(loc, val, xferOp.getSource(), indices);
});
b.create<scf::YieldOp>(loc);
Expand All @@ -1614,7 +1613,7 @@ struct Strategy1d<TransferWriteOp> {
/// This pattern generates IR as follows:
///
/// 1. Generate a for loop iterating over each vector element.
/// 2. Inside the loop, generate a InsertElementOp or ExtractElementOp,
/// 2. Inside the loop, generate an InsertOp or ExtractOp,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: an

/// depending on OpTy.
///
/// TODO: In some cases (no masking, etc.), LLVM::MatrixColumnMajorLoadOp
Expand All @@ -1630,7 +1629,7 @@ struct Strategy1d<TransferWriteOp> {
/// Is rewritten to approximately the following pseudo-IR:
/// ```
/// for i = 0 to 9 {
/// %t = vector.extractelement %vec[i] : vector<9xf32>
/// %t = vector.extract %vec[i] : vector<9xf32>
/// memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32>
/// }
/// ```
Expand Down
10 changes: 5 additions & 5 deletions mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1134,8 +1134,6 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
// * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
// (0th) element and use that.
SmallVector<Value> transferReadIdxs;
auto zero = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI32Type(), rewriter.getZeroAttr(rewriter.getI32Type()));
for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
Value idx = bvm.lookup(extractOp.getIndices()[i]);
if (idx.getType().isIndex()) {
Expand All @@ -1149,7 +1147,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
resultType.getScalableDims().back()),
idx);
transferReadIdxs.push_back(
rewriter.create<vector::ExtractElementOp>(loc, indexAs1dVector, zero));
rewriter.create<vector::ExtractOp>(loc, indexAs1dVector, 0));
}

// `tensor.extract` is always in-bounds, hence the following holds.
Expand Down Expand Up @@ -1415,7 +1413,8 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
// 3.c. Not all ops support 0-d vectors, extract the scalar for now.
// TODO: remove this.
if (readType.getRank() == 0)
readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
readValue = rewriter.create<vector::ExtractOp>(loc, readValue,
SmallVector<int64_t>{});

LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue
<< "\n");
Expand Down Expand Up @@ -2268,7 +2267,8 @@ LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
loc, readType, copyOp.getSource(), indices,
rewriter.getMultiDimIdentityMap(srcType.getRank()));
if (cast<VectorType>(readValue.getType()).getRank() == 0) {
readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
readValue = rewriter.create<vector::ExtractOp>(loc, readValue,
SmallVector<int64_t>{});
readValue = rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
}
Operation *writeValue = rewriter.create<vector::TransferWriteOp>(
Expand Down
23 changes: 15 additions & 8 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -697,8 +697,9 @@ struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
Value result;
if (vectorType.getRank() == 0) {
if (mask)
mask = rewriter.create<ExtractElementOp>(loc, mask);
result = rewriter.create<ExtractElementOp>(loc, reductionOp.getVector());
mask = rewriter.create<ExtractOp>(loc, mask, SmallVector<int64_t>{});
result = rewriter.create<ExtractOp>(loc, reductionOp.getVector(),
SmallVector<int64_t>{});
} else {
if (mask)
mask = rewriter.create<ExtractOp>(loc, mask, 0);
Expand Down Expand Up @@ -1983,12 +1984,18 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
if (extractResultRank < broadcastSrcRank)
return failure();

// Special case if broadcast src is a 0D vector.
// If extractResultRank is 0, broadcastSrcRank has to be zero, since
// broadcastSrcRank >= extractResultRank for this pattern. If so, the input
// to the broadcast will be a vector<f32> or f32, but the result will be a
// f32, because of vector.extract 0-d semantics. Therefore, we instead
// just replace the broadcast with a vector.extract.
if (extractResultRank == 0) {
assert(broadcastSrcRank == 0 && llvm::isa<VectorType>(source.getType()));
rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(extractOp, source);
rewriter.replaceOpWithNewOp<vector::ExtractOp>(extractOp, source,
SmallVector<int64_t>{});
return success();
}

rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
extractOp, extractOp.getType(), source);
return success();
Expand Down Expand Up @@ -2951,11 +2958,11 @@ void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
InsertOpConstantFolder>(context);
}

// Eliminates insert operations that produce values identical to their source
// value. This happens when the source and destination vectors have identical
// sizes.
OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
if (getNumIndices() == 0)
// Fold "vector.insert %v, %dest [] : vector<2x2xf32> into vector<2x2xf32>" to
// %v. Note: Do not fold "vector.insert %v, %dest [] : f32 into vector<f32>"
// (type mismatch).
if (getNumIndices() == 0 && getSourceType() == getResult().getType())
return getSource();
return {};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
if (srcRank <= 1 && dstRank == 1) {
Value ext;
if (srcRank == 0)
ext = rewriter.create<vector::ExtractElementOp>(loc, op.getSource());
ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(),
SmallVector<int64_t>{});
else
ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);
rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -391,9 +391,8 @@ struct TwoDimMultiReductionToReduction
reductionOp = mlir::vector::maskOperation(rewriter, reductionOp, mask);
}

result = rewriter.create<vector::InsertElementOp>(
loc, reductionOp->getResult(0), result,
rewriter.create<arith::ConstantIndexOp>(loc, i));
result = rewriter.create<vector::InsertOp>(loc, reductionOp->getResult(0),
result, i);
}

rewriter.replaceOp(rootOp, result);
Expand Down
20 changes: 2 additions & 18 deletions mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,24 +177,8 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
}

Value extract;
if (srcRank == 0) {
// 0-D vector special case
assert(srcIdx.empty() && "Unexpected indices for 0-D vector");
extract = rewriter.create<vector::ExtractElementOp>(
loc, op.getSourceVectorType().getElementType(), op.getSource());
} else {
extract =
rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
}

if (resRank == 0) {
// 0-D vector special case
assert(resIdx.empty() && "Unexpected indices for 0-D vector");
result = rewriter.create<vector::InsertElementOp>(loc, extract, result);
} else {
result =
rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
}
extract = rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
result = rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
}
rewriter.replaceOp(op, result);
return success();
Expand Down
12 changes: 5 additions & 7 deletions mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1238,7 +1238,7 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
if (extractOp.getNumIndices() == 0)
return failure();

// Rewrite vector.extract with 1d source to vector.extractelement.
// Rewrite vector.extract with a 1-D source to a vector.extract with a single
// static position.
if (extractSrcType.getRank() == 1) {
if (extractOp.hasDynamicPosition())
// TODO: Dynamic position not supported yet.
Expand All @@ -1247,9 +1247,8 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
assert(extractOp.getNumIndices() == 1 && "expected 1 index");
int64_t pos = extractOp.getStaticPosition()[0];
rewriter.setInsertionPoint(extractOp);
rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(
extractOp, extractOp.getVector(),
rewriter.create<arith::ConstantIndexOp>(loc, pos));
rewriter.replaceOpWithNewOp<vector::ExtractOp>(
extractOp, extractOp.getVector(), pos);
return success();
}

Expand Down Expand Up @@ -1519,9 +1518,8 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
assert(insertOp.getNumIndices() == 1 && "expected 1 index");
int64_t pos = insertOp.getStaticPosition()[0];
rewriter.setInsertionPoint(insertOp);
rewriter.replaceOpWithNewOp<vector::InsertElementOp>(
insertOp, insertOp.getSource(), insertOp.getDest(),
rewriter.create<arith::ConstantIndexOp>(loc, pos));
rewriter.replaceOpWithNewOp<vector::InsertOp>(
insertOp, insertOp.getSource(), insertOp.getDest(), pos);
return success();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,13 @@ using namespace mlir::vector;
// Helper that picks the proper sequence for inserting.
static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
Value into, int64_t offset) {
auto vectorType = cast<VectorType>(into.getType());
if (vectorType.getRank() > 1)
return rewriter.create<InsertOp>(loc, from, into, offset);
return rewriter.create<vector::InsertElementOp>(
loc, vectorType, from, into,
rewriter.create<arith::ConstantIndexOp>(loc, offset));
return rewriter.create<InsertOp>(loc, from, into, offset);
}

// Helper that picks the proper sequence for extracting.
static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
int64_t offset) {
auto vectorType = cast<VectorType>(vector.getType());
if (vectorType.getRank() > 1)
return rewriter.create<ExtractOp>(loc, vector, offset);
return rewriter.create<vector::ExtractElementOp>(
loc, vectorType.getElementType(), vector,
rewriter.create<arith::ConstantIndexOp>(loc, offset));
return rewriter.create<ExtractOp>(loc, vector, offset);
}

/// RewritePattern for InsertStridedSliceOp where source and destination vectors
Expand Down Expand Up @@ -277,8 +267,8 @@ class Convert1DExtractStridedSliceIntoExtractInsertChain final
};

/// RewritePattern for ExtractStridedSliceOp where the source vector is n-D.
/// For such cases, we can rewrite it to ExtractOp/ExtractElementOp + lower
/// rank ExtractStridedSliceOp + InsertOp/InsertElementOp for the n-D case.
/// For such cases, we can rewrite it to ExtractOp + lower rank
/// ExtractStridedSliceOp + InsertOp for the n-D case.
class DecomposeNDExtractStridedSlice
: public OpRewritePattern<ExtractStridedSliceOp> {
public:
Expand Down
24 changes: 12 additions & 12 deletions mlir/test/Conversion/ArithToAMDGPU/16-bit-floats.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
func.func @scalar_trunc(%v: f32) -> f16{
// CHECK: %[[poison:.*]] = llvm.mlir.poison : f32
// CHECK: %[[trunc:.*]] = rocdl.cvt.pkrtz %[[value]], %[[poison]] : vector<2xf16>
// CHECK: %[[extract:.*]] = vector.extractelement %[[trunc]][%c0 : index] : vector<2xf16>
// CHECK: %[[extract:.*]] = vector.extract %[[trunc]][0] : f16 from vector<2xf16>
// CHECK: return %[[extract]] : f16
%w = arith.truncf %v : f32 to f16
return %w : f16
Expand All @@ -14,8 +14,8 @@ func.func @scalar_trunc(%v: f32) -> f16{
// CHECK-LABEL: @vector_trunc
// CHECK-SAME: (%[[value:.*]]: vector<2xf32>)
func.func @vector_trunc_short(%v: vector<2xf32>) -> vector<2xf16> {
// CHECK: %[[elem0:.*]] = vector.extractelement %[[value]]
// CHECK: %[[elem1:.*]] = vector.extractelement %[[value]]
// CHECK: %[[elem0:.*]] = vector.extract %[[value]]
// CHECK: %[[elem1:.*]] = vector.extract %[[value]]
// CHECK: %[[ret:.*]] = rocdl.cvt.pkrtz %[[elem0]], %[[elem1]] : vector<2xf16>
// CHECK: return %[[ret]]
%w = arith.truncf %v : vector<2xf32> to vector<2xf16>
Expand All @@ -25,23 +25,23 @@ func.func @vector_trunc_short(%v: vector<2xf32>) -> vector<2xf16> {
// CHECK-LABEL: @vector_trunc_long
// CHECK-SAME: (%[[value:.*]]: vector<9xf32>)
func.func @vector_trunc_long(%v: vector<9xf32>) -> vector<9xf16> {
// CHECK: %[[elem0:.*]] = vector.extractelement %[[value]][%c0 : index]
// CHECK: %[[elem1:.*]] = vector.extractelement %[[value]][%c1 : index]
// CHECK: %[[elem0:.*]] = vector.extract %[[value]][0]
// CHECK: %[[elem1:.*]] = vector.extract %[[value]][1]
// CHECK: %[[packed0:.*]] = rocdl.cvt.pkrtz %[[elem0]], %[[elem1]] : vector<2xf16>
// CHECK: %[[out0:.*]] = vector.insert_strided_slice %[[packed0]], {{.*}} {offsets = [0], strides = [1]} : vector<2xf16> into vector<9xf16>
// CHECK: %[[elem2:.*]] = vector.extractelement %[[value]][%c2 : index]
// CHECK: %[[elem3:.*]] = vector.extractelement %[[value]][%c3 : index]
// CHECK: %[[elem2:.*]] = vector.extract %[[value]][2]
// CHECK: %[[elem3:.*]] = vector.extract %[[value]][3]
// CHECK: %[[packed1:.*]] = rocdl.cvt.pkrtz %[[elem2]], %[[elem3]] : vector<2xf16>
// CHECK: %[[out1:.*]] = vector.insert_strided_slice %[[packed1]], %[[out0]] {offsets = [2], strides = [1]} : vector<2xf16> into vector<9xf16>
// CHECK: %[[elem4:.*]] = vector.extractelement %[[value]][%c4 : index]
// CHECK: %[[elem5:.*]] = vector.extractelement %[[value]][%c5 : index]
// CHECK: %[[elem4:.*]] = vector.extract %[[value]][4]
// CHECK: %[[elem5:.*]] = vector.extract %[[value]][5]
// CHECK: %[[packed2:.*]] = rocdl.cvt.pkrtz %[[elem4]], %[[elem5]] : vector<2xf16>
// CHECK: %[[out2:.*]] = vector.insert_strided_slice %[[packed2]], %[[out1]] {offsets = [4], strides = [1]} : vector<2xf16> into vector<9xf16>
// CHECK: %[[elem6:.*]] = vector.extractelement %[[value]]
// CHECK: %[[elem7:.*]] = vector.extractelement %[[value]]
// CHECK: %[[elem6:.*]] = vector.extract %[[value]]
// CHECK: %[[elem7:.*]] = vector.extract %[[value]]
// CHECK: %[[packed3:.*]] = rocdl.cvt.pkrtz %[[elem6]], %[[elem7]] : vector<2xf16>
// CHECK: %[[out3:.*]] = vector.insert_strided_slice %[[packed3]], %[[out2]] {offsets = [6], strides = [1]} : vector<2xf16> into vector<9xf16>
// CHECK: %[[elem8:.*]] = vector.extractelement %[[value]]
// CHECK: %[[elem8:.*]] = vector.extract %[[value]]
// CHECK: %[[packed4:.*]] = rocdl.cvt.pkrtz %[[elem8:.*]] : vector<2xf16>
// CHECK: %[[slice:.*]] = vector.extract_strided_slice %[[packed4]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf16> to vector<1xf16>
// CHECK: %[[out4:.*]] = vector.insert_strided_slice %[[slice]], %[[out3]] {offsets = [8], strides = [1]} : vector<1xf16> into vector<9xf16>
Expand Down
3 changes: 1 addition & 2 deletions mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -233,10 +233,9 @@ func.func @broadcast_vec2d_from_vec0d(%arg0: vector<f32>) -> vector<3x2xf32> {
// CHECK-LABEL: @broadcast_vec2d_from_vec0d(
// CHECK-SAME: %[[A:.*]]: vector<f32>)
// CHECK: %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<f32> to vector<1xf32>
// CHECK: %[[T5:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<1xf32> to f32
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, why wouldn't this be lowered to llvm.extractelement? And is T5 used at all? In fact, would you mind sharing the output?

// CHECK: %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32>
// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
// CHECK: %[[T4:.*]] = llvm.mlir.constant(0 : index) : i64
// CHECK: %[[T5:.*]] = llvm.extractelement %[[T0]][%[[T4]] : i64] : vector<1xf32>
// CHECK: %[[T6Insert:.*]] = llvm.insertelement %[[T5]]
// CHECK: %[[T6:.*]] = llvm.shufflevector %[[T6Insert]]
// CHECK: %[[T7:.*]] = llvm.insertvalue %[[T6]], %[[T2]][0] : !llvm.array<3 x vector<2xf32>>
Expand Down
Loading
Loading