@@ -1096,43 +1096,55 @@ class VectorExtractOpConversion
1096
1096
SmallVector<OpFoldResult> positionVec = getMixedValues (
1097
1097
adaptor.getStaticPosition (), adaptor.getDynamicPosition (), rewriter);
1098
1098
1099
- // Extract entire vector. Should be handled by folder, but just to be safe.
1100
- ArrayRef<OpFoldResult> position (positionVec);
1101
- if (position.empty ()) {
1102
- rewriter.replaceOp (extractOp, adaptor.getVector ());
1103
- return success ();
1104
- }
1105
-
1106
- // One-shot extraction of vector from array (only requires extractvalue).
1107
- // Except for extracting 1-element vectors.
1108
- if (isa<VectorType>(resultType) &&
1109
- position.size () !=
1110
- static_cast <size_t >(extractOp.getSourceVectorType ().getRank ())) {
1111
- if (extractOp.hasDynamicPosition ())
1112
- return failure ();
1113
-
1114
- Value extracted = rewriter.create <LLVM::ExtractValueOp>(
1115
- loc, adaptor.getVector (), getAsIntegers (position));
1116
- rewriter.replaceOp (extractOp, extracted);
1117
- return success ();
1099
+ // The Vector -> LLVM lowering models N-D vectors as nested aggregates of
1100
+ // 1-d vectors. This nesting is modeled using arrays. We do this conversion
1101
+ // from a N-d vector extract to a nested aggregate vector extract in two
1102
+ // steps:
1103
+ // - Extract a member from the nested aggregate. The result can be
1104
+ // a lower rank nested aggregate or a vector (1-D). This is done using
1105
+ // `llvm.extractvalue`.
1106
+ // - Extract a scalar out of the vector if needed. This is done using
1107
+ // `llvm.extractelement`.
1108
+
1109
+ // Determine if we need to extract a member out of the aggregate. We
1110
+ // always need to extract a member if the input rank >= 2.
1111
+ bool extractsAggregate = extractOp.getSourceVectorType ().getRank () >= 2 ;
1112
+ // Determine if we need to extract a scalar as the result. We extract
1113
+ // a scalar if the extract is full rank, i.e., the number of indices is
1114
+ // equal to source vector rank.
1115
+ bool extractsScalar = static_cast <int64_t >(positionVec.size ()) ==
1116
+ extractOp.getSourceVectorType ().getRank ();
1117
+
1118
+ // Since the LLVM type converter converts 0-d vectors to 1-d vectors, we
1119
+ // need to add a position for this change.
1120
+ if (extractOp.getSourceVectorType ().getRank () == 0 ) {
1121
+ Type idxType = typeConverter->convertType (rewriter.getIndexType ());
1122
+ positionVec.push_back (rewriter.getZeroAttr (idxType));
1118
1123
}
1119
1124
1120
- // Potential extraction of 1-D vector from array.
1121
1125
Value extracted = adaptor.getVector ();
1122
- if (position.size () > 1 ) {
1123
- if (extractOp.hasDynamicPosition ())
1126
+ if (extractsAggregate) {
1127
+ ArrayRef<OpFoldResult> position (positionVec);
1128
+ if (extractsScalar) {
1129
+ // If we are extracting a scalar from the extracted member, we drop
1130
+ // the last index, which will be used to extract the scalar out of the
1131
+ // vector.
1132
+ position = position.drop_back ();
1133
+ }
1134
+ // llvm.extractvalue does not support dynamic dimensions.
1135
+ if (!llvm::all_of (position, llvm::IsaPred<Attribute>)) {
1124
1136
return failure ();
1137
+ }
1138
+ extracted = rewriter.create <LLVM::ExtractValueOp>(
1139
+ loc, extracted, getAsIntegers (position));
1140
+ }
1125
1141
1126
- SmallVector<int64_t > nMinusOnePosition =
1127
- getAsIntegers (position.drop_back ());
1128
- extracted = rewriter.create <LLVM::ExtractValueOp>(loc, extracted,
1129
- nMinusOnePosition);
1142
+ if (extractsScalar) {
1143
+ extracted = rewriter.create <LLVM::ExtractElementOp>(
1144
+ loc, extracted, getAsLLVMValue (rewriter, loc, positionVec.back ()));
1130
1145
}
1131
1146
1132
- Value lastPosition = getAsLLVMValue (rewriter, loc, position.back ());
1133
- // Remaining extraction of element from 1-D LLVM vector.
1134
- rewriter.replaceOpWithNewOp <LLVM::ExtractElementOp>(extractOp, extracted,
1135
- lastPosition);
1147
+ rewriter.replaceOp (extractOp, extracted);
1136
1148
return success ();
1137
1149
}
1138
1150
};
0 commit comments