@@ -1096,43 +1096,50 @@ class VectorExtractOpConversion
1096
1096
SmallVector<OpFoldResult> positionVec = getMixedValues (
1097
1097
adaptor.getStaticPosition (), adaptor.getDynamicPosition (), rewriter);
1098
1098
1099
- // Extract entire vector. Should be handled by folder, but just to be safe.
1100
- ArrayRef<OpFoldResult> position (positionVec);
1101
- if (position.empty ()) {
1102
- rewriter.replaceOp (extractOp, adaptor.getVector ());
1103
- return success ();
1104
- }
1105
-
1106
- // One-shot extraction of vector from array (only requires extractvalue).
1107
- // Except for extracting 1-element vectors.
1108
- if (isa<VectorType>(resultType) &&
1109
- position.size () !=
1110
- static_cast <size_t >(extractOp.getSourceVectorType ().getRank ())) {
1111
- if (extractOp.hasDynamicPosition ())
1112
- return failure ();
1113
-
1114
- Value extracted = rewriter.create <LLVM::ExtractValueOp>(
1115
- loc, adaptor.getVector (), getAsIntegers (position));
1116
- rewriter.replaceOp (extractOp, extracted);
1117
- return success ();
1118
- }
1099
+ // Determine if we need to extract a scalar as the result. We extract
1100
+ // a scalar if the extract is full rank i.e. the number of indices is equal
1101
+ // to source vector rank.
1102
+ bool isScalarExtract =
1103
+ positionVec.size () == extractOp.getSourceVectorType ().getRank ();
1104
+ // Determine if we need to extract a slice out of the original vector. We
1105
+ // always need to extract a slice if the input rank >= 2.
1106
+ bool isSlicingExtract = extractOp.getSourceVectorType ().getRank () >= 2 ;
1119
1107
1120
- // Potential extraction of 1-D vector from array.
1121
1108
Value extracted = adaptor.getVector ();
1122
- if (position.size () > 1 ) {
1123
- if (extractOp.hasDynamicPosition ())
1109
+ if (isSlicingExtract) {
1110
+ ArrayRef<OpFoldResult> position (positionVec);
1111
+ if (isScalarExtract) {
1112
+ // If we are extracting a scalar from the returned slice, we need to
1113
+ // extract a N-1 D slice.
1114
+ position = position.drop_back ();
1115
+ }
1116
+ // llvm.extractvalue does not support dynamic dimensions.
1117
+ if (!llvm::all_of (position,
1118
+ [](OpFoldResult x) { return isa<Attribute>(x); })) {
1124
1119
return failure ();
1120
+ }
1121
+ extracted = rewriter.create <LLVM::ExtractValueOp>(
1122
+ loc, extracted, getAsIntegers (position));
1123
+ }
1125
1124
1126
- SmallVector<int64_t > nMinusOnePosition =
1127
- getAsIntegers (position.drop_back ());
1128
- extracted = rewriter.create <LLVM::ExtractValueOp>(loc, extracted,
1129
- nMinusOnePosition);
1125
+ if (isScalarExtract) {
1126
+ Value position;
1127
+ if (positionVec.empty ()) {
1128
+ // A scalar extract with no position is a 0-D vector extract. The LLVM
1129
+ // type converter converts 0-D vectors to 1-D vectors, so we need to add
1130
+ // a constant position.
1131
+ auto idxType = rewriter.getIndexType ();
1132
+ position = rewriter.create <LLVM::ConstantOp>(
1133
+ loc, typeConverter->convertType (idxType),
1134
+ rewriter.getIntegerAttr (idxType, 0 ));
1135
+ } else {
1136
+ position = getAsLLVMValue (rewriter, loc, positionVec.back ());
1137
+ }
1138
+ extracted =
1139
+ rewriter.create <LLVM::ExtractElementOp>(loc, extracted, position);
1130
1140
}
1131
1141
1132
- Value lastPosition = getAsLLVMValue (rewriter, loc, position.back ());
1133
- // Remaining extraction of element from 1-D LLVM vector.
1134
- rewriter.replaceOpWithNewOp <LLVM::ExtractElementOp>(extractOp, extracted,
1135
- lastPosition);
1142
+ rewriter.replaceOp (extractOp, extracted);
1136
1143
return success ();
1137
1144
}
1138
1145
};
0 commit comments